Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and moved its logic into `LanguageModelRegistry`. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
1bfea9d443
commit
99bc90a372
32 changed files with 478 additions and 691 deletions
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::BTreeMap;
|
||||
|
@ -36,6 +36,7 @@ pub struct AnthropicSettings {
|
|||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub max_tokens: usize,
|
||||
pub tool_override: Option<String>,
|
||||
}
|
||||
|
||||
pub struct AnthropicLanguageModelProvider {
|
||||
|
@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
anthropic::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
tool_override: model.tool_override.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
|
@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
self.state.read(cx).api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
|
@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
|||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let state = self.state.clone();
|
||||
let delete_credentials =
|
||||
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
|
||||
|
@ -171,6 +174,7 @@ pub struct AnthropicModel {
|
|||
model: anthropic::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
pub fn count_anthropic_tokens(
|
||||
|
@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
|
|||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = request.into_anthropic(self.model.id().into());
|
||||
let request = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await?;
|
||||
Ok(anthropic::extract_text_from_events(response).boxed())
|
||||
}
|
||||
.boxed()
|
||||
Ok(anthropic::extract_text_from_events(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
tool_name: String,
|
||||
|
@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
|
|||
input_schema: serde_json::Value,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
let mut request = request.into_anthropic(self.model.id().into());
|
||||
let mut request = request.into_anthropic(self.model.tool_model_id().into());
|
||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||
name: tool_name.clone(),
|
||||
});
|
||||
|
@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel {
|
|||
}];
|
||||
|
||||
let response = self.request_completion(request, cx);
|
||||
async move {
|
||||
let response = response.await?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = response.await?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.context("tool not used")
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.context("tool not used")
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
|
|||
use crate::{
|
||||
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use client::Client;
|
||||
|
@ -41,6 +41,7 @@ pub struct AvailableModel {
|
|||
provider: AvailableProvider,
|
||||
name: String,
|
||||
max_tokens: usize,
|
||||
tool_override: Option<String>,
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModelProvider {
|
||||
|
@ -56,7 +57,7 @@ struct State {
|
|||
}
|
||||
|
||||
impl State {
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
tool_override: model.tool_override.clone(),
|
||||
}),
|
||||
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
|
@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
client: self.client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
|
@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
self.state.read(cx).status.is_connected()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.state.read(cx).authenticate(cx)
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
|
@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
|
|||
id: LanguageModelId,
|
||||
model: CloudModel,
|
||||
client: Arc<Client>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
|
@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_anthropic(model.id().into());
|
||||
async move {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
|
@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||
.await?;
|
||||
Ok(anthropic::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_open_ai(model.id().into());
|
||||
async move {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
|
@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||
.await?;
|
||||
Ok(open_ai::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
async move {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let stream = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
|
@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
|
|||
.await?;
|
||||
Ok(google_ai::extract_text_from_events(
|
||||
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
tool_name: String,
|
||||
|
@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let mut request = request.into_anthropic(model.id().into());
|
||||
let mut request = request.into_anthropic(model.tool_model_id().into());
|
||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||
name: tool_name.clone(),
|
||||
});
|
||||
|
@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
|
|||
input_schema,
|
||||
}];
|
||||
|
||||
async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request(proto::CompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request(proto::CompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
let response: anthropic::Response =
|
||||
serde_json::from_str(&response.completion)?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.context("tool not used")
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.context("tool not used")
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
|
||||
|
|
|
@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
|
|||
use crate::LanguageModelProviderState;
|
||||
use crate::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
|
||||
use super::open_ai::count_open_ai_tokens;
|
||||
|
@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||
|
||||
fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
CopilotChatModel::iter()
|
||||
.map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
|
||||
.map(|model| {
|
||||
Arc::new(CopilotChatLanguageModel {
|
||||
model,
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let result = if self.is_authenticated(cx) {
|
||||
Ok(())
|
||||
} else if let Some(copilot) = Copilot::global(cx) {
|
||||
|
@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||
cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let Some(copilot) = Copilot::global(cx) else {
|
||||
return Task::ready(Err(anyhow::anyhow!(
|
||||
"Copilot is not available. Please ensure Copilot is enabled and running and try again."
|
||||
|
@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
|||
|
||||
pub struct CopilotChatLanguageModel {
|
||||
model: CopilotChatModel,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LanguageModel for CopilotChatLanguageModel {
|
||||
|
@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
|
|||
return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
cx.spawn(|mut cx| async move {
|
||||
let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(result) => {
|
||||
let choice = result.choices.first();
|
||||
match choice {
|
||||
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
|
||||
None => Some(Err(anyhow::anyhow!(
|
||||
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
|
||||
))),
|
||||
let request_limiter = self.request_limiter.clone();
|
||||
let future = cx.spawn(|cx| async move {
|
||||
let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
|
||||
request_limiter.stream(async move {
|
||||
let response = response.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(result) => {
|
||||
let choice = result.choices.first();
|
||||
match choice {
|
||||
Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
|
||||
None => Some(Err(anyhow::anyhow!(
|
||||
"The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
|
||||
))),
|
||||
}
|
||||
}
|
||||
Err(err) => Some(Err(err)),
|
||||
}
|
||||
Err(err) => Some(Err(err)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}).await
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
|
|
|
@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
|
|||
unimplemented!()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
|
|||
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
|
|
|
@ -20,7 +20,7 @@ use util::ResultExt;
|
|||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
|
||||
};
|
||||
|
||||
const PROVIDER_ID: &str = "google";
|
||||
|
@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
rate_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
|
@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
self.state.read(cx).api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
|
@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
|||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let state = self.state.clone();
|
||||
let delete_credentials =
|
||||
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
|
||||
|
@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
|
|||
model: google_ai::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
rate_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LanguageModel for GoogleLanguageModel {
|
||||
|
@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
|
|||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let future = self.rate_limiter.stream(async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let response =
|
||||
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let events = response.await?;
|
||||
Ok(google_ai::extract_text_from_events(events).boxed())
|
||||
}
|
||||
.boxed()
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
|
|
|
@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
|
|||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
|
@ -39,7 +39,7 @@ struct State {
|
|||
}
|
||||
|
||||
impl State {
|
||||
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
|
||||
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
|
@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
|
|||
}),
|
||||
}),
|
||||
};
|
||||
this.fetch_models(cx).detach();
|
||||
this.state
|
||||
.update(cx, |state, cx| state.fetch_models(cx).detach());
|
||||
this
|
||||
}
|
||||
|
||||
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
|
||||
let state = self.state.clone();
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(|mut cx| async move {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
|
||||
let mut models: Vec<ollama::Model> = models
|
||||
.into_iter()
|
||||
// Since there is no metadata from the Ollama API
|
||||
// indicating which models are embedding models,
|
||||
// simply filter out models with "-embed" in their name
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| ollama::Model::new(&model.name))
|
||||
.collect();
|
||||
|
||||
models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
state.update(&mut cx, |this, cx| {
|
||||
this.available_models = models;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||
|
@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
id: LanguageModelId::from(model.name.clone()),
|
||||
model: model.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
|
@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
!self.state.read(cx).available_models.is_empty()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
self.fetch_models(cx)
|
||||
self.state.update(cx, |state, cx| state.fetch_models(cx))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
|||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
self.fetch_models(cx)
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.fetch_models(cx))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
|
|||
id: LanguageModelId,
|
||||
model: ollama::Model,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl OllamaLanguageModel {
|
||||
|
@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("ollama/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
|
@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let request =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
|
||||
let response = request.await?;
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
|
||||
.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
|
@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
|
|||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
|
|
|
@ -20,7 +20,7 @@ use util::ResultExt;
|
|||
use crate::{
|
||||
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, Role,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
|
||||
const PROVIDER_ID: &str = "openai";
|
||||
|
@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect()
|
||||
|
@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
self.state.read(cx).api_key.is_some()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated(cx) {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
|
@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
|||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||
let delete_credentials = cx.delete_credentials(&settings.api_url);
|
||||
let state = self.state.clone();
|
||||
|
@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
|
|||
model: open_ai::Model,
|
||||
state: gpui::Model<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAiLanguageModel {
|
||||
|
@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
|
@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
);
|
||||
let response = request.await?;
|
||||
Ok(open_ai::extract_text_from_events(response).boxed())
|
||||
}
|
||||
.boxed()
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
|
||||
fn use_tool(
|
||||
fn use_any_tool(
|
||||
&self,
|
||||
_request: LanguageModelRequest,
|
||||
_name: String,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue