Allow customization of the model used for tool calling (#15479)

We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.
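
For example, a hypothetical sketch reusing the `anthropic::Model::Custom` variant this PR extends (the model names are illustrative): a custom model can now route tool calls to a different model via the new `tool_override` field.

```rust
// Illustrative only: the model names are placeholders.
// When `tool_override` is set, tool calls are sent to that model instead of
// the one used for ordinary completions.
let model = anthropic::Model::Custom {
    name: "claude-3-5-sonnet-20240620".into(),
    max_tokens: 200_000,
    tool_override: Some("claude-3-opus-20240229".into()),
};
```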

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra, 2024-07-30 16:18:53 +02:00 (committed by GitHub)
parent 1bfea9d443
commit 99bc90a372
32 changed files with 478 additions and 691 deletions

View file

@@ -33,6 +33,7 @@ google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
inline_completion_button.workspace = true
log.workspace = true
menu.workspace = true
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
@@ -42,6 +43,7 @@ schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true

View file

@@ -1,24 +1,24 @@
mod model;
pub mod provider;
mod rate_limiter;
mod registry;
mod request;
mod role;
pub mod settings;
use std::sync::Arc;
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
pub use model::*;
use project::Fs;
pub(crate) use rate_limiter::*;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use std::{future::Future, sync::Arc};
pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
settings::init(fs, cx);
@@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
name: String,
@@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<'static, Result<serde_json::Value>>;
}
impl dyn LanguageModel {
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
let response = request.await?;
Ok(serde_json::from_value(response)?)
}
}
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
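
As an illustration of the typed wrapper above, here is a minimal sketch of implementing `LanguageModelTool`, assuming the trait has no required items beyond the `name` and `description` shown in this hunk; `FileSummary` and its field are invented for the example:

```rust
use schemars::JsonSchema;
use serde::Deserialize;

// Hypothetical tool type: the model's tool-use output is deserialized into it,
// and its JSON schema is derived via `schemars` for the request.
#[derive(Deserialize, JsonSchema)]
struct FileSummary {
    /// A one-sentence summary of the file.
    summary: String,
}

impl LanguageModelTool for FileSummary {
    fn name() -> String {
        "file_summary".into()
    }
    fn description() -> String {
        "Produce a one-sentence summary of a file".into()
    }
}
```

Callers then receive a typed value instead of raw JSON, e.g. `let summary: FileSummary = model.use_tool(request, cx).await?;`, with the schema generated by `schemars::schema_for!` as in the wrapper above.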
@@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
pub trait LanguageModelProviderState: 'static {

View file

@@ -1,7 +1,7 @@
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
@@ -36,6 +36,7 @@ pub struct AnthropicSettings {
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
pub tool_override: Option<String>,
}
pub struct AnthropicLanguageModelProvider {
@@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
},
);
}
@@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
@@ -171,6 +174,7 @@ pub struct AnthropicModel {
model: anthropic::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
pub fn count_anthropic_tokens(
@@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
let request = self.stream_completion(request, cx);
async move {
let future = self.request_limiter.stream(async move {
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
Ok(anthropic::extract_text_from_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
@@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into());
let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@@ -322,25 +326,26 @@
}];
let response = self.request_completion(request, cx);
async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
self.request_limiter
.run(async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
})
.context("tool not used")
})
.boxed()
}
}

View file

@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use anyhow::{anyhow, Context as _, Result};
use client::Client;
@@ -41,6 +41,7 @@ pub struct AvailableModel {
provider: AvailableProvider,
name: String,
max_tokens: usize,
tool_override: Option<String>,
}
pub struct CloudLanguageModelProvider {
@@ -56,7 +57,7 @@ struct State {
}
impl State {
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
@@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
}),
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
@@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()),
model,
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
self.state.read(cx).status.is_connected()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.state.read(cx).authenticate(cx)
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.into()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
@@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
client: Arc<Client>,
request_limiter: RateLimiter,
}
impl LanguageModel for CloudLanguageModel {
@@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let request = request.into_anthropic(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(anthropic::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(open_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
async move {
let future = self.request_limiter.stream(async move {
let request = serde_json::to_string(&request)?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
@@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
.await?;
Ok(google_ai::extract_text_from_events(
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
.boxed()
))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
}
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
@@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let mut request = request.into_anthropic(model.id().into());
let mut request = request.into_anthropic(model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response =
serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
})
.context("tool not used")
})
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()

View file

@@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
use crate::LanguageModelProviderState;
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
};
use super::open_ai::count_open_ai_tokens;
@@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
CopilotChatModel::iter()
.map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
.map(|model| {
Arc::new(CopilotChatLanguageModel {
model,
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
@@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
.unwrap_or(false)
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
let result = if self.is_authenticated(cx) {
Ok(())
} else if let Some(copilot) = Copilot::global(cx) {
@@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let Some(copilot) = Copilot::global(cx) else {
return Task::ready(Err(anyhow::anyhow!(
"Copilot is not available. Please ensure Copilot is enabled and running and try again."
@@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
pub struct CopilotChatLanguageModel {
model: CopilotChatModel,
request_limiter: RateLimiter,
}
impl LanguageModel for CopilotChatLanguageModel {
@@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
};
cx.spawn(|mut cx| async move {
let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(result) => {
let choice = result.choices.first();
match choice {
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(|cx| async move {
let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
request_limiter.stream(async move {
let response = response.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(result) => {
let choice = result.choices.first();
match choice {
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
None => Some(Err(anyhow::anyhow!(
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
))),
}
}
Err(err) => Some(Err(err)),
}
Err(err) => Some(Err(err)),
}
})
.boxed();
Ok(stream)
})
.boxed()
})
.boxed();
Ok(stream)
}).await
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
true
}
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
@@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
unimplemented!()
}
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
@@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
const PROVIDER_ID: &str = "google";
@@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
model: google_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
}
impl LanguageModel for GoogleLanguageModel {
@@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@@ -39,7 +39,7 @@ struct State {
}
impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
@@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
}),
}),
};
this.fetch_models(cx).detach();
this.state
.update(cx, |state, cx| state.fetch_models(cx).detach());
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
@@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
!self.state.read(cx).available_models.is_empty()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.fetch_models(cx)
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OllamaLanguageModel {
@@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
@@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let request =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
let response = request.await?;
let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let stream = response
.filter_map(|response| async move {
match response {
@@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
})
.boxed();
Ok(stream)
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const PROVIDER_ID: &str = "openai";
@@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
@@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
model: open_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl LanguageModel for OpenAiLanguageModel {
@@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
@@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
);
let response = request.await?;
Ok(open_ai::extract_text_from_events(response).boxed())
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,

View file

@@ -0,0 +1,70 @@
use anyhow::Result;
use futures::Stream;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
#[derive(Clone)]
pub struct RateLimiter {
semaphore: Arc<Semaphore>,
}
pub struct RateLimitGuard<T> {
inner: T,
_guard: SemaphoreGuardArc,
}
impl<T> Stream for RateLimitGuard<T>
where
T: Stream,
{
type Item = T::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
}
}
impl RateLimiter {
pub fn new(limit: usize) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(limit)),
}
}
pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
where
Fut: 'a + Future<Output = Result<T>>,
{
let guard = self.semaphore.acquire_arc();
async move {
let guard = guard.await;
let result = future.await?;
drop(guard);
Ok(result)
}
}
pub fn stream<'a, Fut, T>(
&self,
future: Fut,
) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item>>>
where
Fut: 'a + Future<Output = Result<T>>,
T: Stream,
{
let guard = self.semaphore.acquire_arc();
async move {
let guard = guard.await;
let inner = future.await?;
Ok(RateLimitGuard {
inner,
_guard: guard,
})
}
}
}
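
A minimal usage sketch of the limiter above (the values and the `demo` function are illustrative, not part of this PR). Note the design choice in `stream`: the semaphore permit rides inside `RateLimitGuard`, so a slot stays occupied until the consumer drops the stream, not merely until the connection is established.

```rust
use futures::StreamExt as _;

fn demo() -> anyhow::Result<()> {
    smol::block_on(async {
        // At most 2 futures hold a permit at once; the rest wait.
        let limiter = RateLimiter::new(2);

        // `run` gates a one-shot future behind the semaphore.
        let value = limiter.run(async { anyhow::Ok(42) }).await?;
        assert_eq!(value, 42);

        // `stream` gates a stream-producing future and keeps the permit
        // alive for the lifetime of the returned stream.
        let events = limiter
            .stream(async { anyhow::Ok(futures::stream::iter([1, 2, 3])) })
            .await?;
        assert_eq!(events.collect::<Vec<i32>>().await, vec![1, 2, 3]);
        Ok(())
    })
}
```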

View file

@@ -4,11 +4,12 @@ use crate::{
copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use client::Client;
use collections::BTreeMap;
use gpui::{AppContext, Global, Model, ModelContext};
use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
use std::sync::Arc;
use ui::Context;
@@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
pub struct ActiveModel {
provider: Arc<dyn LanguageModelProvider>,
model: Option<Arc<dyn LanguageModel>>,
}
pub struct ActiveModelChanged;
impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
impl LanguageModelRegistry {
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelRegistry>().0.clone()
@@ -88,6 +99,8 @@ impl LanguageModelRegistry {
let registry = cx.new_model(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
registry.set_active_model(Some(model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -136,6 +149,64 @@
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
pub fn select_active_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut ModelContext<Self>,
) {
let Some(provider) = self.provider(&provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_active_model(Some(model), cx);
}
}
pub fn set_active_provider(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
cx: &mut ModelContext<Self>,
) {
self.active_model = provider.map(|provider| ActiveModel {
provider,
model: None,
});
cx.emit(ActiveModelChanged);
}
pub fn set_active_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut ModelContext<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.active_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(ActiveModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(ActiveModelChanged);
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
Some(self.active_model.as_ref()?.provider.clone())
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.as_ref()?.model.clone()
}
}
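
A short sketch of driving the new registry API (the id strings are illustrative, and `LanguageModelProviderId::from(String)` is an assumption mirroring the `LanguageModelId::from(String)` conversion used elsewhere in this diff):

```rust
fn activate(registry: &gpui::Model<LanguageModelRegistry>, cx: &mut gpui::AppContext) {
    registry.update(cx, |registry, cx| {
        // Looks up the provider, finds the model by id, and emits
        // `ActiveModelChanged` to any subscribers.
        registry.select_active_model(
            &LanguageModelProviderId::from("anthropic".to_string()),
            &LanguageModelId::from("claude-3-5-sonnet".to_string()),
            cx,
        );
    });
    let _active = registry.read(cx).active_model();
}
```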
#[cfg(test)]

View file

@@ -89,9 +89,15 @@ impl AnthropicSettingsContent {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom { name, max_tokens } => {
Some(provider::anthropic::AvailableModel { name, max_tokens })
}
anthropic::Model::Custom {
name,
max_tokens,
tool_override,
} => Some(provider::anthropic::AvailableModel {
name,
max_tokens,
tool_override,
}),
_ => None,
})
.collect()
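
For context, `tool_model_id()` (called by the Anthropic code paths above) is defined in the `anthropic` crate and is not shown in this diff; a hypothetical reconstruction of the fallback it presumably implements:

```rust
// Hypothetical sketch, as it might appear inside the `anthropic` crate:
// use the override for tool calls when present, else the model's own id.
impl Model {
    pub fn tool_model_id(&self) -> &str {
        match self {
            Model::Custom {
                name, tool_override, ..
            } => tool_override.as_deref().unwrap_or(name),
            other => other.id(),
        }
    }
}
```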