assistant: Use tools in other providers (#15803)
- [x] OpenAI - [ ] ~Google~ Moved into a separate branch at: https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've ran into issues with having the API digest our schema without tripping over itself - the function call parameters are malformed and whatnot. We can resume from that branch if needed. - [x] Ollama - [x] Cloud - [ ] ~Copilot Chat (?)~ Release Notes: - Added tool calling capabilities to OpenAI and Ollama models.
This commit is contained in:
parent
be514f23e1
commit
874f0c0712
5 changed files with 392 additions and 64 deletions
|
@ -4,7 +4,7 @@ use crate::{
|
|||
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use anyhow::{anyhow, bail, Context as _, Result};
|
||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
|
||||
|
@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel {
|
|||
})
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
|
||||
CloudModel::OpenAi(model) => {
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
let client = self.client.clone();
|
||||
let mut function = open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: None,
|
||||
parameters: None,
|
||||
};
|
||||
let func = open_ai::ToolDefinition::Function {
|
||||
function: function.clone(),
|
||||
};
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
// Call arguments are gonna be streamed in over multiple chunks.
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::Google(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||
}
|
||||
CloudModel::Zed(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
|
||||
CloudModel::Zed(model) => {
|
||||
// All Zed models are OpenAI-based at the time of writing.
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
let client = self.client.clone();
|
||||
let mut function = open_ai::FunctionDefinition {
|
||||
name: tool_name.clone(),
|
||||
description: None,
|
||||
parameters: None,
|
||||
};
|
||||
let func = open_ai::ToolDefinition::Function {
|
||||
function: function.clone(),
|
||||
};
|
||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
||||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
// Call arguments are gonna be streamed in over multiple chunks.
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue