Allow customization of the model used for tool calling (#15479)
We also eliminate the `completion` crate and move its logic into `LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
parent 1bfea9d443
commit 99bc90a372
32 changed files with 478 additions and 691 deletions
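The heart of the change is the new `tool_override` field on `AvailableModel`, threaded into `anthropic::Model::Custom` below and consulted through `model.tool_model_id()` when a tool call is forced. Here is a minimal sketch of how such an accessor could resolve the override; the `Custom` fields come from this diff, but the method body is inferred from its call site in `use_any_tool`, not copied from the source:

pub enum Model {
    // Named variants elided; this diff only touches Custom.
    Custom {
        name: String,
        max_tokens: usize,
        /// Optional model to substitute when the request must call a tool.
        tool_override: Option<String>,
    },
}

impl Model {
    pub fn id(&self) -> &str {
        match self {
            Model::Custom { name, .. } => name,
        }
    }

    /// Model id to use for tool calling: the override when configured,
    /// otherwise the model's own id.
    pub fn tool_model_id(&self) -> &str {
        match self {
            Model::Custom {
                tool_override: Some(tool_override),
                ..
            } => tool_override,
            _ => self.id(),
        }
    }
}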
@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::Client;
@@ -41,6 +41,7 @@ pub struct AvailableModel {
     provider: AvailableProvider,
     name: String,
     max_tokens: usize,
+    tool_override: Option<String>,
 }

 pub struct CloudLanguageModelProvider {
@@ -56,7 +57,7 @@ struct State {
 }

 impl State {
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
@@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                 name: model.name.clone(),
                 max_tokens: model.max_tokens,
+                tool_override: model.tool_override.clone(),
             }),
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                 name: model.name.clone(),
@@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 id: LanguageModelId::from(model.id().to_string()),
                 model,
                 client: self.client.clone(),
+                request_limiter: RateLimiter::new(4),
             }) as Arc<dyn LanguageModel>
         })
         .collect()
@@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         self.state.read(cx).status.is_connected()
     }

-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.state.read(cx).authenticate(cx)
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.authenticate(cx))
     }

     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         .into()
     }

-    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 }
@@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
     client: Arc<Client>,
+    request_limiter: RateLimiter,
 }

 impl LanguageModel for CloudLanguageModel {
@@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
                 let request = request.into_anthropic(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(anthropic::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(open_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(google_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }

-    fn use_tool(
+    fn use_any_tool(
         &self,
         request: LanguageModelRequest,
         tool_name: String,
@@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_anthropic(model.id().into());
+                let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });
@@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];

-                async move {
-                    let request = serde_json::to_string(&request)?;
-                    let response = client
-                        .request(proto::CompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Anthropic as i32,
-                            request,
-                        })
-                        .await?;
-                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
-                    response
-                        .content
-                        .into_iter()
-                        .find_map(|content| {
-                            if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                if name == tool_name {
-                                    Some(input)
-                                } else {
-                                    None
-                                }
-                            } else {
-                                None
-                            }
-                        })
-                        .context("tool not used")
-                }
-                .boxed()
+                self.request_limiter
+                    .run(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let response = client
+                            .request(proto::CompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Anthropic as i32,
+                                request,
+                            })
+                            .await?;
+                        let response: anthropic::Response =
+                            serde_json::from_str(&response.completion)?;
+                        response
+                            .content
+                            .into_iter()
+                            .find_map(|content| {
+                                if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                    if name == tool_name {
+                                        Some(input)
+                                    } else {
+                                        None
+                                    }
+                                } else {
+                                    None
+                                }
+                            })
+                            .context("tool not used")
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
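The other recurring edit above is routing every upstream call through the new `request_limiter` (`RateLimiter::new(4)`): `.run(...)` wraps one-shot requests such as the tool call, and `.stream(...)` wraps the streaming completions. As a rough mental model only — assuming a semaphore-based design; the names and the tokio semaphore below are illustrative, not Zed's actual implementation:

use std::future::Future;
use std::sync::Arc;

use tokio::sync::Semaphore; // assumption: any async semaphore would do

/// Illustrative stand-in for the RateLimiter introduced in this commit:
/// it caps the number of in-flight requests at `limit`.
#[derive(Clone)]
pub struct RateLimiter {
    semaphore: Arc<Semaphore>,
}

impl RateLimiter {
    pub fn new(limit: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(limit)),
        }
    }

    /// Acquire a permit, then drive the future; the permit (one of
    /// `limit` slots) is released when the future completes.
    pub async fn run<T>(&self, fut: impl Future<Output = T>) -> T {
        let _permit = self.semaphore.acquire().await.expect("semaphore closed");
        fut.await
    }
}

This matches the call shape in the diff, e.g. `self.request_limiter.run(async move { ... }).boxed()` in `use_any_tool`.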