Allow customization of the model used for tool calling (#15479)

We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.
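
For illustration, a minimal sketch of what a custom model carrying the new override might look like, using the `AvailableModel` fields from this diff (the construction site and the concrete model names are assumptions, not part of the PR):

```rust
// Sketch only: `AvailableModel` / `AvailableProvider` are the types from
// this diff; the model names here are illustrative.
let model = AvailableModel {
    provider: AvailableProvider::Anthropic,
    // Model used for ordinary completions.
    name: "claude-3-5-sonnet".into(),
    max_tokens: 200_000,
    // When set, tool calls are routed to this model instead.
    tool_override: Some("claude-3-opus".into()),
};
```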

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>

Antonio Scandurra, 2024-07-30 16:18:53 +02:00 (committed by GitHub)
commit 99bc90a372 (parent 1bfea9d443)
32 changed files with 478 additions and 691 deletions
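
The diff below also threads a `RateLimiter` through `CloudLanguageModel` (`RateLimiter::new(4)`, `request_limiter.stream(...)`, `request_limiter.run(...)`). Its implementation is not part of this excerpt; the following is a minimal sketch of an interface matching those call sites, assuming a semaphore that caps in-flight requests (signatures inferred, not the actual type):

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::Future;
use smol::lock::Semaphore;

// Minimal sketch inferred from the call sites in the diff below; the
// real `RateLimiter` may differ.
#[derive(Clone)]
pub struct RateLimiter {
    semaphore: Arc<Semaphore>,
}

impl RateLimiter {
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    // Await a permit, then run the request future; at most
    // `max_concurrent` requests are in flight at once.
    pub async fn run<T>(&self, future: impl Future<Output = Result<T>>) -> Result<T> {
        let _guard = self.semaphore.acquire().await;
        future.await
    }

    // Same idea for requests that open a stream: hold a permit while the
    // future that produces the stream is awaited.
    pub async fn stream<S>(&self, future: impl Future<Output = Result<S>>) -> Result<S> {
        let _guard = self.semaphore.acquire().await;
        future.await
    }
}
```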

@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::Client;
@@ -41,6 +41,7 @@ pub struct AvailableModel {
     provider: AvailableProvider,
     name: String,
     max_tokens: usize,
+    tool_override: Option<String>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -56,7 +57,7 @@ struct State {
 }
 
 impl State {
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
@@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                 name: model.name.clone(),
                 max_tokens: model.max_tokens,
+                tool_override: model.tool_override.clone(),
             }),
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                 name: model.name.clone(),
@@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                     id: LanguageModelId::from(model.id().to_string()),
                     model,
                     client: self.client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         self.state.read(cx).status.is_connected()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.state.read(cx).authenticate(cx)
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.authenticate(cx))
     }
 
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 }
@@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
     client: Arc<Client>,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for CloudLanguageModel {
@@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
                 let request = request.into_anthropic(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(anthropic::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(open_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(google_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         request: LanguageModelRequest,
         tool_name: String,
@@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_anthropic(model.id().into());
+                let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });
@@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                async move {
-                    let request = serde_json::to_string(&request)?;
-                    let response = client
-                        .request(proto::CompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Anthropic as i32,
-                            request,
-                        })
-                        .await?;
-                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
-                    response
-                        .content
-                        .into_iter()
-                        .find_map(|content| {
-                            if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                if name == tool_name {
-                                    Some(input)
-                                } else {
-                                    None
-                                }
-                            } else {
-                                None
-                            }
-                        })
-                        .context("tool not used")
-                }
-                .boxed()
+                self.request_limiter
+                    .run(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let response = client
+                            .request(proto::CompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Anthropic as i32,
+                                request,
+                            })
+                            .await?;
+                        let response: anthropic::Response =
+                            serde_json::from_str(&response.completion)?;
+                        response
+                            .content
+                            .into_iter()
+                            .find_map(|content| {
+                                if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                    if name == tool_name {
+                                        Some(input)
+                                    } else {
+                                        None
+                                    }
+                                } else {
+                                    None
+                                }
+                            })
+                            .context("tool not used")
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
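
One accessor the hunks above call but do not define is `tool_model_id()`. Given the `tool_override` field added to `anthropic::Model::Custom`, a plausible sketch of its behavior (an assumption about the shape of the real code, with a stand-in `Model` enum so the example is self-contained):

```rust
// Stand-in for `anthropic::Model`, reduced to what the sketch needs.
pub enum Model {
    Claude3_5Sonnet,
    Custom {
        name: String,
        max_tokens: usize,
        tool_override: Option<String>,
    },
}

impl Model {
    pub fn id(&self) -> &str {
        match self {
            Model::Claude3_5Sonnet => "claude-3-5-sonnet",
            Model::Custom { name, .. } => name,
        }
    }

    /// Sketch: prefer the overridden model for tool calls when one is
    /// configured; otherwise fall back to the model's regular id.
    pub fn tool_model_id(&self) -> &str {
        match self {
            Model::Custom {
                tool_override: Some(tool_override),
                ..
            } => tool_override,
            other => other.id(),
        }
    }
}
```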