Use tool calling instead of XML parsing to generate edit operations (#15385)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-07-29 16:42:08 +02:00 committed by GitHub
parent f6012cd86e
commit 6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions


@@ -16,6 +16,8 @@ pub use model::*;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn use_tool(
&self,
request: LanguageModelRequest,
name: String,
description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>>;
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
}
pub trait LanguageModelProvider: 'static {

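For context, a minimal sketch of implementing the new `LanguageModelTool` trait. The `EditOperations` type and its field are hypothetical stand-ins, not code from this change; the point is that deriving `Deserialize` and `JsonSchema` satisfies the trait bounds, so the JSON schema handed to `use_tool` can be generated mechanically (for example with `schemars::schema_for!`):

```rust
use schemars::JsonSchema;
use serde::Deserialize;

// Hypothetical tool type, for illustration only.
#[derive(Deserialize, JsonSchema)]
struct EditOperations {
    /// Edits the model proposes to apply.
    operations: Vec<String>,
}

impl LanguageModelTool for EditOperations {
    fn name() -> String {
        "edit_operations".into()
    }

    fn description() -> String {
        "Propose a batch of edit operations".into()
    }
}
```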

@@ -1,5 +1,9 @@
use anthropic::stream_completion;
use anyhow::{anyhow, Result};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
.boxed()
}
impl AnthropicModel {
fn request_completion(
&self,
request: anthropic::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<anthropic::Response>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
}
.boxed()
}
fn stream_completion(
&self,
request: anthropic::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = anthropic::stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
request.await
}
.boxed()
}
}
impl LanguageModel for AnthropicModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let request = self.stream_completion(request, cx);
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
fn use_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
request.tools = vec![anthropic::Tool {
name: tool_name.clone(),
description: tool_description,
input_schema,
}];
let response = self.request_completion(request, cx);
async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
}
}
struct AuthenticationPrompt {

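A sketch of how a caller might drive this `use_tool` implementation end to end, reusing the hypothetical `EditOperations` tool from the earlier sketch; the helper and its name are assumptions, not part of this diff:

```rust
use anyhow::Result;

// Assumed helper, not from this change: asks the model to invoke the tool,
// then deserializes the returned JSON input back into the typed struct the
// schema was generated from.
async fn request_edits(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    cx: &AsyncAppContext,
) -> Result<EditOperations> {
    let schema = serde_json::to_value(schemars::schema_for!(EditOperations))?;
    let input = model
        .use_tool(
            request,
            EditOperations::name(),
            EditOperations::description(),
            schema,
            cx,
        )
        .await?;
    Ok(serde_json::from_value(input)?)
}
```

Because `tool_choice` is pinned to the named tool, the model must answer with a `ToolUse` content block, which is what the `find_map` above extracts.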

@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
};
async move {
let request = serde_json::to_string(&request)?;
let response = client.request(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::CountTokens as i32,
request,
});
let response = response.await?;
let response =
serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
Ok(response.total_tokens)
let response = client
.request(proto::CountLanguageModelTokens {
provider: proto::LanguageModelProvider::Google as i32,
request,
})
.await?;
Ok(response.token_count as usize)
}
.boxed()
}
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_anthropic(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
Ok(anthropic::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_open_ai(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
Ok(open_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_google(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
request,
})
.await?;
Ok(google_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
}
}
}
fn use_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let mut request = request.into_anthropic(model.id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
request.tools = vec![anthropic::Tool {
name: tool_name.clone(),
description: tool_description,
input_schema,
}];
async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
}
}
}
}
struct AuthenticationPrompt {

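This file replaces the generic `QueryLanguageModel` RPC with purpose-built messages (`CountLanguageModelTokens`, `StreamCompleteWithLanguageModel`, `CompleteWithLanguageModel`), and each streamed item now carries a JSON-encoded `event` payload. A reduced sketch of the per-item decoding pattern used above, with `EventEnvelope` standing in for the real proto response type:

```rust
use anyhow::Result;
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;

// Stand-in for the proto stream item; the real type comes from `proto`.
struct EventEnvelope {
    event: String,
}

// Decode each JSON `event` string into the provider's typed event, keeping
// the stream lazy and propagating both transport and parse errors.
fn decode_events<T: DeserializeOwned>(
    stream: impl Stream<Item = Result<EventEnvelope>>,
) -> impl Stream<Item = Result<T>> {
    stream.map(|item| Ok(serde_json::from_str::<T>(&item?.event)?))
}
```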

@@ -1,15 +1,17 @@
use std::sync::{Arc, Mutex};
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
use anyhow::anyhow;
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
use std::{
future,
sync::{Arc, Mutex},
};
use ui::WindowContext;
pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
.insert(serde_json::to_string(&request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}

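The fake provider, like the Google, Ollama, and OpenAI providers below, stubs `use_tool` by returning an immediately-ready, boxed error future. A generic sketch of that pattern (this helper itself is not in the diff):

```rust
use anyhow::{anyhow, Result};
use futures::{future, future::BoxFuture, FutureExt};

// Assumed helper, for illustration: a ready future avoids spawning an async
// block just to report that the capability is unsupported.
fn unsupported<T: Send + 'static>(feature: &str) -> BoxFuture<'static, Result<T>> {
    future::ready(Err(anyhow!("{feature} not implemented"))).boxed()
}
```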

@@ -9,7 +9,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
struct AuthenticationPrompt {


@@ -6,7 +6,7 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
struct DownloadOllamaMessage {


@@ -9,7 +9,7 @@ use gpui::{
use http_client::HttpClient;
use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
pub fn count_open_ai_tokens(


@@ -106,19 +106,27 @@ impl LanguageModelRequest {
messages: new_messages
.into_iter()
.filter_map(|message| {
Some(anthropic::RequestMessage {
Some(anthropic::Message {
role: match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => return None,
},
content: message.content,
content: vec![anthropic::Content::Text {
text: message.content,
}],
})
})
.collect(),
stream: true,
max_tokens: 4092,
system: system_message,
system: Some(system_message),
tools: Vec::new(),
tool_choice: None,
metadata: None,
stop_sequences: Vec::new(),
temperature: None,
top_k: None,
top_p: None,
}
}
}
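
A reduced sketch of the message mapping this hunk implements: system turns are hoisted into the request-level `system` field, while user and assistant turns become `anthropic::Message` values whose content is a single text block. The `(Role, String)` input shape is a simplification for illustration; the real conversion operates on `LanguageModelRequest` messages:

```rust
// Simplified stand-in for the conversion above, not the shipped code.
fn split_messages(messages: Vec<(Role, String)>) -> (String, Vec<anthropic::Message>) {
    let mut system = String::new();
    let mut converted = Vec::new();
    for (role, text) in messages {
        match role {
            // System prompts are concatenated into the dedicated field.
            Role::System => system.push_str(&text),
            // Conversation turns become single-text-block messages.
            Role::User | Role::Assistant => converted.push(anthropic::Message {
                role: if matches!(role, Role::User) {
                    anthropic::Role::User
                } else {
                    anthropic::Role::Assistant
                },
                content: vec![anthropic::Content::Text { text }],
            }),
        }
    }
    (system, converted)
}
```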