Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -1,7 +1,7 @@
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
@ -36,6 +36,7 @@ pub struct AnthropicSettings {
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
pub tool_override: Option<String>,
}
pub struct AnthropicLanguageModelProvider {
@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
anthropic::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
},
);
}
@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
@ -171,6 +174,7 @@ pub struct AnthropicModel {
model: anthropic::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
pub fn count_anthropic_tokens(
@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
let request = self.stream_completion(request, cx);
async move {
let future = self.request_limiter.stream(async move {
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
Ok(anthropic::extract_text_from_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into());
let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel {
}];
let response = self.request_completion(request, cx);
async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
self.request_limiter
.run(async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
})
.context("tool not used")
})
.boxed()
}
}