assistant: Use tools in other providers (#15803)

- [x] OpenAI
- [ ] ~Google~ Moved into a separate branch at:
https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've
run into issues with having the API digest our schema without tripping
over itself - the function call parameters are malformed and whatnot. We
can resume from that branch if needed.
- [x] Ollama
- [x] Cloud
- [ ] ~Copilot Chat (?)~

Release Notes:

- Added tool calling capabilities to OpenAI and Ollama models.
Piotr Osiewicz 2024-08-06 15:45:47 +02:00 committed by GitHub
parent be514f23e1
commit 874f0c0712
5 changed files with 392 additions and 64 deletions
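
For context on the diffs below: each provider now implements `use_any_tool`, which takes a tool name, a description, and a plain JSON Schema for the tool's parameters, and resolves to the arguments the model produced, parsed as a `serde_json::Value`. A minimal sketch of the kind of schema a caller would hand in (the `open_file` tool here is hypothetical, not part of this change):

```rust
use serde_json::json;

// Hypothetical parameter schema for an `open_file` tool; `use_any_tool`
// forwards a schema like this to the provider and returns the model's
// arguments as JSON.
fn open_file_schema() -> serde_json::Value {
    json!({
        "type": "object",
        "properties": {
            "path": { "type": "string", "description": "Path of the file to open" }
        },
        "required": ["path"]
    })
}

fn main() {
    println!("{}", serde_json::to_string_pretty(&open_file_schema()).unwrap());
}
```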


@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use anyhow::{anyhow, bail, Context as _, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
@@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel {
})
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone();
let mut function = open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = open_ai::ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for the tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
// Call arguments will be streamed in over multiple chunks.
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
}
CloudModel::Zed(_) => {
future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
CloudModel::Zed(model) => {
// All Zed models are OpenAI-based at the time of writing.
let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone();
let mut function = open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = open_ai::ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for the tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
// Call arguments will be streamed in over multiple chunks.
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
}
}
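
The OpenAI and Zed arms above share the same streaming pattern: the forced function's arguments arrive as string fragments spread across multiple deltas, keyed by the tool call index, and are concatenated before being parsed. A self-contained sketch of that accumulation, using stand-in types rather than the actual `open_ai` crate structs:

```rust
use serde_json::Value;

// Minimal stand-ins for the streamed delta types; the real `open_ai` crate
// types in this diff carry more fields.
struct FunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}
struct ToolCallDelta {
    index: usize,
    function: Option<FunctionDelta>,
}

/// Accumulate argument fragments for the call that matches `tool_name`,
/// then parse the concatenated JSON. Mirrors the loop in the diff above.
fn collect_arguments(deltas: Vec<ToolCallDelta>, tool_name: &str) -> anyhow::Result<Value> {
    let mut load_state: Option<(String, usize)> = None;
    for call in deltas {
        if let Some(func) = call.function {
            if func.name.as_deref() == Some(tool_name) {
                load_state = Some((String::new(), call.index));
            }
            if let Some((arguments, (output, index))) = func.arguments.zip(load_state.as_mut()) {
                if call.index == *index {
                    output.push_str(&arguments);
                }
            }
        }
    }
    match load_state {
        Some((arguments, _)) => Ok(serde_json::from_str(&arguments)?),
        None => anyhow::bail!("tool not used"),
    }
}

fn main() -> anyhow::Result<()> {
    // Two fragments of `{"path":"README.md"}` streamed for call index 0.
    let deltas = vec![
        ToolCallDelta {
            index: 0,
            function: Some(FunctionDelta {
                name: Some("open_file".into()),
                arguments: Some("{\"path\":".into()),
            }),
        },
        ToolCallDelta {
            index: 0,
            function: Some(FunctionDelta {
                name: None,
                arguments: Some("\"README.md\"}".into()),
            }),
        },
    ];
    println!("{}", collect_arguments(deltas, "open_file")?);
    Ok(())
}
```

Only fragments whose index matches the call that introduced the tool name are appended; anything the model streams for other calls is ignored.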


@@ -1,12 +1,14 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
@@ -184,6 +186,7 @@ impl OllamaLanguageModel {
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.content,
@@ -198,8 +201,25 @@ impl OllamaLanguageModel {
temperature: Some(request.temperature),
..Default::default()
}),
tools: vec![],
}
}
fn request_completion(
&self,
request: ChatRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
}
}
impl LanguageModel for OllamaLanguageModel {
@@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel {
Ok(delta) => {
let content = match delta.message {
ChatMessage::User { content } => content,
ChatMessage::Assistant { content } => content,
ChatMessage::Assistant { content, .. } => content,
ChatMessage::System { content } => content,
};
Some(Ok(content))
@@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel {
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(schema),
};
let tools = vec![OllamaTool::Function { function }];
let request = self.to_ollama_request(request).with_tools(tools);
let response = self.request_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
let ChatMessage::Assistant {
tool_calls,
content,
} = response.message
else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
return Ok(function.arguments);
}
}
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
// Parse content as arguments.
return Ok(args);
} else {
bail!("assistant message does not have any tool calls");
};
bail!("tool not used")
})
.boxed()
}
}
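
Ollama's path is not streamed: `request_completion` issues a single `ollama::complete` call, and the tool arguments are either taken from a structured call on the assistant message or, failing that, recovered by parsing the message content as JSON. A rough sketch of that selection logic with stand-in types (the real `OllamaToolCall` and `ChatMessage` shapes live in the `ollama` crate):

```rust
use serde_json::Value;

// Stand-ins for the relevant `ollama` crate shapes used in the diff above.
struct FunctionCall {
    name: String,
    arguments: Value,
}
enum ToolCall {
    Function(FunctionCall),
}

/// Pick the arguments for `tool_name` out of a completed assistant turn,
/// falling back to treating the plain content as a JSON argument object.
fn extract_tool_arguments(
    tool_calls: Option<Vec<ToolCall>>,
    content: &str,
    tool_name: &str,
) -> anyhow::Result<Value> {
    if let Some(calls) = tool_calls.filter(|calls| !calls.is_empty()) {
        for call in calls {
            let ToolCall::Function(function) = call;
            if function.name == tool_name {
                return Ok(function.arguments);
            }
        }
    } else if let Ok(args) = serde_json::from_str::<Value>(content) {
        // Some models answer the tool prompt with bare JSON instead of a
        // structured tool call.
        return Ok(args);
    }
    anyhow::bail!("tool not used")
}

fn main() -> anyhow::Result<()> {
    // A model that skipped the structured call and answered with bare JSON.
    let args = extract_tool_arguments(None, r#"{"path": "Cargo.toml"}"#, "open_file")?;
    println!("{args}");
    Ok(())
}
```

The content fallback mirrors the diff's own `else if` branch: it keeps tool use working with models that never emit structured calls.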


@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -7,11 +7,13 @@ use gpui::{
View, WhiteSpace,
};
use http_client::HttpClient;
use open_ai::stream_completion;
use open_ai::{
stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Indicator};
@@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel {
request_limiter: RateLimiter,
}
impl OpenAiLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = request.into_open_ai(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
Ok(open_ai::extract_text_from_events(response).boxed())
});
async move { Ok(future.await?.boxed()) }.boxed()
let completions = self.stream_completion(request, cx);
async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
}
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
let mut request = request.into_open_ai(self.model.id().into());
let mut function = FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for the tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(schema);
request.tools = vec![ToolDefinition::Function { function }];
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let mut response = response.await?;
// Call arguments will be streamed in over multiple chunks.
let mut load_state = None;
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
}
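
Both the standalone OpenAI provider and the Cloud OpenAI/Zed arms force the model to call the one offered function by setting `tool_choice` to it and listing it in `tools`. Assuming the `open_ai` crate follows the public chat completions wire format (the exact serialization of `ToolChoice::Other` may differ), the request body these fields produce looks roughly like this, with a hypothetical `open_file` tool and an example model id:

```rust
use serde_json::json;

fn main() {
    // Assumed wire shape, modeled on the public OpenAI chat completions API;
    // model id and tool are examples, not values from this diff.
    let body = json!({
        "model": "gpt-4o",
        "messages": [{ "role": "user", "content": "Open the project readme" }],
        "tools": [{
            "type": "function",
            "function": {
                "name": "open_file",
                "description": "Open a file in the editor",
                "parameters": {
                    "type": "object",
                    "properties": { "path": { "type": "string" } },
                    "required": ["path"]
                }
            }
        }],
        "tool_choice": {
            "type": "function",
            "function": { "name": "open_file" }
        }
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```

Leaving `description` and `parameters` off the forced-choice copy keeps that object minimal, which is why the diff fills them in only on the `tools` entry.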