Use tool calling instead of XML parsing to generate edit operations (#15385)
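
This replaces the XML-based protocol for generating edit operations with the providers' native tool-calling APIs: instead of prompting the model to emit XML and parsing it out of the completion text, the request now carries a JSON schema for the edit operations and forces the model to invoke that tool (see the new `use_tool` method in the diff below).

As a rough sketch of the idea (the actual schema used by the assistant lives in other files of this PR; the types below are hypothetical), the tool's `input_schema` can be derived with `schemars`:

```rust
use schemars::{schema_for, JsonSchema};
use serde::Deserialize;

// Hypothetical shape of the tool input; the real definition lives elsewhere
// in this PR.
#[derive(Debug, Deserialize, JsonSchema)]
struct EditOperation {
    /// Path of the buffer to edit.
    path: String,
    /// Text to find in the buffer.
    old_text: String,
    /// Replacement text.
    new_text: String,
}

#[derive(Debug, Deserialize, JsonSchema)]
struct EditTool {
    operations: Vec<EditOperation>,
}

fn main() {
    // This JSON schema is what a caller hands to `use_tool` as `input_schema`.
    let schema = schema_for!(EditTool);
    println!("{}", serde_json::to_string_pretty(&schema).unwrap());
}
```

Because the tool input is constrained to the schema, the model's answer comes back as structured JSON rather than free-form text that has to be scraped.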

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-07-29 16:42:08 +02:00 committed by GitHub
parent f6012cd86e
commit 6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions

@@ -4,7 +4,7 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest,
 };
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
 use client::Client;
 use collections::BTreeMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
         };
         async move {
             let request = serde_json::to_string(&request)?;
-            let response = client.request(proto::QueryLanguageModel {
-                provider: proto::LanguageModelProvider::Google as i32,
-                kind: proto::LanguageModelRequestKind::CountTokens as i32,
-                request,
-            });
-            let response = response.await?;
-            let response =
-                serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
-            Ok(response.total_tokens)
+            let response = client
+                .request(proto::CountLanguageModelTokens {
+                    provider: proto::LanguageModelProvider::Google as i32,
+                    request,
+                })
+                .await?;
+            Ok(response.token_count as usize)
         }
         .boxed()
     }
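
Note the protocol change in this hunk: the catch-all `proto::QueryLanguageModel` message, which multiplexed request types through a `kind` field, is split into dedicated RPCs (`CountLanguageModelTokens` here; `StreamCompleteWithLanguageModel` and `CompleteWithLanguageModel` below), each with a typed response. A sketch of the message shapes this call site implies, with field types inferred from the diff rather than copied from the actual `.proto` definitions:

```rust
// Inferred from the call sites in this diff; the real generated types live
// in the proto crate and may differ in detail.
pub struct CountLanguageModelTokens {
    pub provider: i32,   // proto::LanguageModelProvider as i32
    pub request: String, // provider-specific request body, serialized as JSON
}

pub struct CountLanguageModelTokensResponse {
    pub token_count: u32, // cast to usize at the call site
}
```

The token count is now a plain field on the response instead of a JSON string that had to be re-parsed into `google_ai::CountTokensResponse`.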
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_anthropic(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Anthropic as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(anthropic::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_open_ai(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::OpenAi as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::OpenAi as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(open_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
                 let request = request.into_google(model.id().into());
                 async move {
                     let request = serde_json::to_string(&request)?;
-                    let response = client.request_stream(proto::QueryLanguageModel {
-                        provider: proto::LanguageModelProvider::Google as i32,
-                        kind: proto::LanguageModelRequestKind::Complete as i32,
-                        request,
-                    });
-                    let chunks = response.await?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Google as i32,
+                            request,
+                        })
+                        .await?;
                     Ok(google_ai::extract_text_from_events(
-                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                     )
                     .boxed())
                 }
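
The Anthropic, OpenAI, and Google arms above all change in the same way: the stream items now carry a typed `event` payload instead of a generic `response` string. Purely to make the shared shape explicit (this helper is not part of the PR), the decoding step could be written generically:

```rust
use anyhow::Result;
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;

// Stand-in for the generated stream-item type; the real name in the proto
// crate may differ.
pub struct StreamItem {
    pub event: String, // provider-specific event, serialized as JSON
}

// Decode each streamed event into the provider's event type before handing
// the stream to that provider's `extract_text_from_events`.
fn decode_events<T: DeserializeOwned>(
    stream: impl Stream<Item = Result<StreamItem>>,
) -> impl Stream<Item = Result<T>> {
    stream.map(|item| Ok(serde_json::from_str(&item?.event)?))
}
```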
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
             }
         }
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        match &self.model {
+            CloudModel::Anthropic(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_anthropic(model.id().into());
+                request.tool_choice = Some(anthropic::ToolChoice::Tool {
+                    name: tool_name.clone(),
+                });
+                request.tools = vec![anthropic::Tool {
+                    name: tool_name.clone(),
+                    description: tool_description,
+                    input_schema,
+                }];
+
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client
+                        .request(proto::CompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
+                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+                    response
+                        .content
+                        .into_iter()
+                        .find_map(|content| {
+                            if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                if name == tool_name {
+                                    Some(input)
+                                } else {
+                                    None
+                                }
+                            } else {
+                                None
+                            }
+                        })
+                        .context("tool not used")
+                }
+                .boxed()
+            }
+            CloudModel::OpenAi(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+            }
+            CloudModel::Google(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+            }
+        }
+    }
 }
 
 struct AuthenticationPrompt {
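
A sketch of what a call site for the new `use_tool` method might look like (the real caller lives in other files of this PR; the tool name, description, and `EditTool` type are the hypothetical ones from the description above):

```rust
use anyhow::Result;
use schemars::schema_for;

// `model` is any LanguageModel, e.g. a CloudLanguageModel backed by an
// Anthropic model; `EditTool` is the hypothetical input type sketched in
// the PR description.
async fn request_edits(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    cx: &AsyncAppContext,
) -> Result<EditTool> {
    let input = model
        .use_tool(
            request,
            "edit".into(),                             // hypothetical tool name
            "Apply a batch of edit operations".into(), // hypothetical description
            serde_json::to_value(schema_for!(EditTool))?,
            cx,
        )
        .await?;
    // `use_tool` resolves to the raw tool input as a serde_json::Value;
    // deserialize it into the typed schema.
    Ok(serde_json::from_value(input)?)
}
```

As implemented here, only the Anthropic arm sets `tool_choice` and extracts the `ToolUse` content block; the OpenAI and Google arms return the `anyhow!` errors above.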