Use tool calling instead of XML parsing to generate edit operations (#15385)

Release Notes:

- N/A

Co-authored-by: Nathan <nathan@zed.dev>

Parent: f6012cd86e
Commit: 6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
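The heart of the change: rather than prompting the model to emit XML and parsing it back out of the text stream, the assistant now registers a single tool whose JSON schema describes the edit operations and forces the model to invoke it via `tool_choice`. As a rough sketch (not part of the diff), the resulting Anthropic request body presumably looks like this; the field names come from the diff below, while the exact wire shape and the "edit" tool name are assumptions based on Anthropic's Messages API:

```rust
use serde_json::json;

// Sketch only: roughly the request body produced by the new code path.
// One tool is registered and `tool_choice` forces the model to call it,
// so the reply arrives as JSON matching `input_schema` instead of XML
// embedded in prose. The "edit" tool name and its schema are hypothetical.
fn example_forced_tool_request() -> serde_json::Value {
    json!({
        "tools": [{
            "name": "edit",
            "description": "Emit a batch of edit operations",
            "input_schema": {
                "type": "object",
                "properties": {
                    "operations": { "type": "array" }
                }
            }
        }],
        "tool_choice": { "type": "tool", "name": "edit" }
    })
}
```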
@@ -16,6 +16,8 @@ pub use model::*;
 pub use registry::*;
 pub use request::*;
 pub use role::*;
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        name: String,
+        description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 }
+
+pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
+    fn name() -> String;
+    fn description() -> String;
+}
 
 pub trait LanguageModelProvider: 'static {
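Not part of the diff: a sketch of how the two new traits are meant to compose. `SearchTool` and the `call_tool` helper below are hypothetical, as is the `language_model` crate path; the pattern is schema out (via `schemars`), typed value back (via `serde`):

```rust
use anyhow::Result;
use gpui::AsyncAppContext;
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelTool}; // assumed path
use schemars::{schema_for, JsonSchema};
use serde::Deserialize;

// Hypothetical tool, for illustration: `JsonSchema` supplies the input
// schema sent to the model, `Deserialize` parses the structured reply.
#[derive(Deserialize, JsonSchema)]
struct SearchTool {
    query: String,
}

impl LanguageModelTool for SearchTool {
    fn name() -> String {
        "search".into()
    }

    fn description() -> String {
        "Search the project for a string".into()
    }
}

// Hypothetical helper showing the intended round trip: schema out, typed
// value back, with no XML parsing in between.
async fn call_tool<T: LanguageModelTool>(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    cx: &AsyncAppContext,
) -> Result<T> {
    let schema = serde_json::to_value(schema_for!(T))?;
    let input = model
        .use_tool(request, T::name(), T::description(), schema, cx)
        .await?;
    Ok(serde_json::from_value(input)?)
}
```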
@@ -1,5 +1,9 @@
-use anthropic::stream_completion;
-use anyhow::{anyhow, Result};
+use crate::{
+    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
 use ui::prelude::*;
 use util::ResultExt;
 
-use crate::{
-    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
-};
-
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
 
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
     .boxed()
 }
 
+impl AnthropicModel {
+    fn request_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<anthropic::Response>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (state.api_key.clone(), settings.api_url.clone())
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
+        }
+        .boxed()
+    }
+
+    fn stream_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (
+                state.api_key.clone(),
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
+            )
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let request = anthropic::stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
+            request.await
+        }
+        .boxed()
+    }
+}
+
 impl LanguageModel for AnthropicModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = request.into_anthropic(self.model.id().into());
-
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-            (
-                state.api_key.clone(),
-                settings.api_url.clone(),
-                settings.low_speed_timeout,
-            )
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
+        let request = self.stream_completion(request, cx);
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-                low_speed_timeout,
-            );
             let response = request.await?;
             Ok(anthropic::extract_text_from_events(response).boxed())
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        let mut request = request.into_anthropic(self.model.id().into());
+        request.tool_choice = Some(anthropic::ToolChoice::Tool {
+            name: tool_name.clone(),
+        });
+        request.tools = vec![anthropic::Tool {
+            name: tool_name.clone(),
+            description: tool_description,
+            input_schema,
+        }];
+
+        let response = self.request_completion(request, cx);
+        async move {
+            let response = response.await?;
+            response
+                .content
+                .into_iter()
+                .find_map(|content| {
+                    if let anthropic::Content::ToolUse { name, input, .. } = content {
+                        if name == tool_name {
+                            Some(input)
+                        } else {
+                            None
+                        }
+                    } else {
+                        None
+                    }
+                })
+                .context("tool not used")
+        }
+        .boxed()
+    }
 }
 
 struct AuthenticationPrompt {
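For context (not part of the diff, and an assumption based on Anthropic's documented `tool_use` content blocks): because the tool choice is forced, the response `content` should carry a `tool_use` block, which the `find_map` above unwraps to its `input`:

```rust
use serde_json::json;

// Sketch only: roughly the Messages API response shape this code consumes.
// The id is elided and the "edit" tool name is hypothetical; `input` is
// JSON conforming to the `input_schema` sent with the request.
fn example_tool_use_response() -> serde_json::Value {
    json!({
        "content": [{
            "type": "tool_use",
            "id": "toolu_01...",
            "name": "edit",
            "input": { "operations": [] }
        }]
    })
}
```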
@@ -4,7 +4,7 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest,
 };
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
 use client::Client;
 use collections::BTreeMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
         };
         async move {
             let request = serde_json::to_string(&request)?;
-            let response = client.request(proto::QueryLanguageModel {
-                provider: proto::LanguageModelProvider::Google as i32,
-                kind: proto::LanguageModelRequestKind::CountTokens as i32,
-                request,
-            });
-            let response = response.await?;
-            let response =
-                serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
-            Ok(response.total_tokens)
+            let response = client
+                .request(proto::CountLanguageModelTokens {
+                    provider: proto::LanguageModelProvider::Google as i32,
+                    request,
+                })
+                .await?;
+            Ok(response.token_count as usize)
         }
         .boxed()
     }
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_anthropic(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Anthropic as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(anthropic::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_open_ai(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::OpenAi as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::OpenAi as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(open_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_google(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Google as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Google as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(google_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
             }
         }
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        match &self.model {
+            CloudModel::Anthropic(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_anthropic(model.id().into());
+                request.tool_choice = Some(anthropic::ToolChoice::Tool {
+                    name: tool_name.clone(),
+                });
+                request.tools = vec![anthropic::Tool {
+                    name: tool_name.clone(),
+                    description: tool_description,
+                    input_schema,
+                }];
+
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client
+                        .request(proto::CompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
+                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+                    response
+                        .content
+                        .into_iter()
+                        .find_map(|content| {
+                            if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                if name == tool_name {
+                                    Some(input)
+                                } else {
+                                    None
+                                }
+                            } else {
+                                None
+                            }
+                        })
+                        .context("tool not used")
+                }
+                .boxed()
+            }
+            CloudModel::OpenAi(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+            }
+            CloudModel::Google(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+            }
+        }
+    }
 }
 
 struct AuthenticationPrompt {
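The cloud provider also drops the catch-all `proto::QueryLanguageModel` message with its `kind` discriminator in favor of purpose-specific RPCs. Judging from the fields accessed in the diff, the generated bindings presumably look roughly like this; everything beyond the field names is an assumption:

```rust
// Sketch of the new proto message bindings inferred from the diff; field
// types and derives are assumptions, not the actual generated code.
pub struct CountLanguageModelTokens {
    pub provider: i32,   // proto::LanguageModelProvider
    pub request: String, // provider-specific request, serialized as JSON
}
pub struct CountLanguageModelTokensResponse {
    pub token_count: u32,
}

pub struct StreamCompleteWithLanguageModel {
    pub provider: i32,
    pub request: String,
}
pub struct StreamCompleteWithLanguageModelResponse {
    pub event: String, // one provider event per message, serialized as JSON
}

pub struct CompleteWithLanguageModel {
    pub provider: i32,
    pub request: String,
}
pub struct CompleteWithLanguageModelResponse {
    pub completion: String, // the full provider response, serialized as JSON
}
```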
@@ -1,15 +1,17 @@
-use std::sync::{Arc, Mutex};
-
-use collections::HashMap;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest,
 };
+use anyhow::anyhow;
+use collections::HashMap;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
+use std::{
+    future,
+    sync::{Arc, Mutex},
+};
 use ui::WindowContext;
 
 pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
             .insert(serde_json::to_string(&request).unwrap(), tx);
         async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }

@@ -9,7 +9,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct AuthenticationPrompt {

@@ -6,7 +6,7 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
 };
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, ElevationIndex};
 
 use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct DownloadOllamaMessage {

@@ -9,7 +9,7 @@ use gpui::{
 use http_client::HttpClient;
 use open_ai::stream_completion;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 pub fn count_open_ai_tokens(

@@ -106,19 +106,27 @@ impl LanguageModelRequest {
             messages: new_messages
                 .into_iter()
                 .filter_map(|message| {
-                    Some(anthropic::RequestMessage {
+                    Some(anthropic::Message {
                         role: match message.role {
                             Role::User => anthropic::Role::User,
                             Role::Assistant => anthropic::Role::Assistant,
                             Role::System => return None,
                         },
-                        content: message.content,
+                        content: vec![anthropic::Content::Text {
+                            text: message.content,
+                        }],
                     })
                 })
                 .collect(),
             stream: true,
             max_tokens: 4092,
-            system: system_message,
+            system: Some(system_message),
+            tools: Vec::new(),
+            tool_choice: None,
+            metadata: None,
+            stop_sequences: Vec::new(),
             temperature: None,
+            top_k: None,
+            top_p: None,
         }
     }
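The `into_anthropic` change above switches message bodies from bare strings to typed content blocks, which is what lets plain text and `tool_use` blocks coexist in a single message. A converted message presumably serializes along these lines (an assumption based on the `Content::Text` variant in the diff and Anthropic's Messages API):

```rust
use serde_json::json;

// Sketch only: a plain-text chat message after conversion to the new
// content-block representation; the message text is hypothetical.
fn example_converted_message() -> serde_json::Value {
    json!({
        "role": "user",
        "content": [
            { "type": "text", "text": "Rename foo to bar" }
        ]
    })
}
```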