Use tool calling instead of XML parsing to generate edit operations (#15385)

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-07-29 16:42:08 +02:00 committed by GitHub
parent f6012cd86e
commit 6e1f7c6e1d
22 changed files with 1155 additions and 853 deletions
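
The diff below replaces the collab server's single XML-parsed QueryLanguageModel RPC with dedicated complete / stream-complete / count-tokens RPCs, which pairs with the assistant's switch to model tool calling for generating edit operations. As a rough illustration of the idea only (not the actual tool schema this commit registers; the tool name and every field below are assumptions), a tool-call-based edit operation might be declared to the Anthropic Messages API like this:

    use serde_json::json;

    // A minimal sketch, NOT the schema from this commit: an "edit" tool the
    // model can call directly, instead of emitting XML for the server to parse.
    // All names here are illustrative assumptions.
    fn example_edit_tool() -> serde_json::Value {
        json!({
            "name": "edit",
            "description": "Propose an edit to a region of a file.",
            "input_schema": {
                "type": "object",
                "properties": {
                    "path": { "type": "string" },
                    "old_text": { "type": "string" },
                    "new_text": { "type": "string" }
                },
                "required": ["path", "old_text", "new_text"]
            }
        })
    }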

crates/collab/src/rpc.rs

@@ -10,7 +10,7 @@ use crate::{
         ServerId, UpdatedChannelMessage, User, UserId,
     },
     executor::Executor,
-    AppState, Error, RateLimit, RateLimiter, Result,
+    AppState, Config, Error, RateLimit, RateLimiter, Result,
 };
 use anyhow::{anyhow, bail, Context as _};
 use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        complete_with_language_model(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_streaming_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
-                    complete_with_language_model(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                        app_state.config.google_ai_api_key.clone(),
-                        app_state.config.anthropic_api_key.clone(),
-                    )
+                    let app_state = app_state.clone();
+                    async move {
+                        stream_complete_with_language_model(
+                            request,
+                            response,
+                            session,
+                            &app_state.config,
+                        )
+                        .await
+                    }
                 }
             })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        count_language_model_tokens(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_request_handler({
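
Note the double clone of app_state in each registration above: the outer clone is captured by the handler closure, and the inner clone is moved into the async move block so each call's future owns its own Arc and stays 'static. A minimal standalone sketch of the same pattern (all names here are illustrative stand-ins, not collab's real types):

    use std::sync::Arc;

    // Sketch only: stands in for collab's AppState.
    struct AppState {
        config: String,
    }

    // Sketch only: stands in for add_request_handler.
    fn register<F, Fut>(_handler: F)
    where
        F: Fn() -> Fut + 'static,
        Fut: std::future::Future<Output = ()> + 'static,
    {
    }

    fn main() {
        let app_state = Arc::new(AppState {
            config: "example".into(),
        });
        register({
            let app_state = app_state.clone(); // outer clone: owned by the closure
            move || {
                let app_state = app_state.clone(); // inner clone: owned by this call's future
                async move {
                    let _config = &app_state.config; // the future keeps app_state alive
                }
            }
        });
    }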
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    query: proto::QueryLanguageModel,
-    response: StreamingResponse<proto::QueryLanguageModel>,
+    request: proto::CompleteWithLanguageModel,
+    response: Response<proto::CompleteWithLanguageModel>,
     session: Session,
-    open_ai_api_key: Option<Arc<str>>,
-    google_ai_api_key: Option<Arc<str>>,
-    anthropic_api_key: Option<Arc<str>>,
+    config: &Config,
 ) -> Result<()> {
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
     };
     authorize_access_to_language_models(&session).await?;
 
-    match proto::LanguageModelRequestKind::from_i32(query.kind) {
-        Some(proto::LanguageModelRequestKind::Complete) => {
-            session
-                .rate_limiter
-                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        Some(proto::LanguageModelRequestKind::CountTokens) => {
-            session
-                .rate_limiter
-                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        None => Err(anyhow!("unknown request kind"))?,
-    }
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
 
-    match proto::LanguageModelProvider::from_i32(query.provider) {
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
         Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key =
-                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            anthropic::complete(
+                session.http_client.as_ref(),
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CompleteWithLanguageModelResponse {
+        completion: serde_json::to_string(&result)?,
+    })?;
+
+    Ok(())
+}
+
+async fn stream_complete_with_language_model(
+    request: proto::StreamCompleteWithLanguageModel,
+    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Anthropic) => {
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
             let mut chunks = anthropic::stream_completion(
                 session.http_client.as_ref(),
                 anthropic::ANTHROPIC_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = chunks.next().await {
+                let chunk = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&chunk)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
-            let mut chunks = open_ai::stream_completion(
+            let api_key = config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let mut events = open_ai::stream_completion(
                 session.http_client.as_ref(),
                 open_ai::OPEN_AI_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::Google) => {
-            let api_key =
-                google_ai_api_key.context("no Google AI API key configured on the server")?;
-            match proto::LanguageModelRequestKind::from_i32(query.kind) {
-                Some(proto::LanguageModelRequestKind::Complete) => {
-                    let mut chunks = google_ai::stream_generate_content(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    while let Some(chunk) = chunks.next().await {
-                        let chunk = chunk?;
-                        response.send(proto::QueryLanguageModelResponse {
-                            response: serde_json::to_string(&chunk)?,
-                        })?;
-                    }
-                }
-                Some(proto::LanguageModelRequestKind::CountTokens) => {
-                    let tokens_response = google_ai::count_tokens(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    response.send(proto::QueryLanguageModelResponse {
-                        response: serde_json::to_string(&tokens_response)?,
-                    })?;
-                }
-                None => Err(anyhow!("unknown request kind"))?,
-            }
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let mut events = google_ai::stream_generate_content(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?;
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
+                })?;
+            }
         }
         None => return Err(anyhow!("unknown provider"))?,
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
     Ok(())
 }
 
-struct CountTokensWithLanguageModelRateLimit;
+async fn count_language_model_tokens(
+    request: proto::CountLanguageModelTokens,
+    response: Response<proto::CountLanguageModelTokens>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
 
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+    session
+        .rate_limiter
+        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Google) => {
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            google_ai::count_tokens(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CountLanguageModelTokensResponse {
+        token_count: result.total_tokens as u32,
+    })?;
+
+    Ok(())
+}
+
+struct CountLanguageModelTokensRateLimit;
+
+impl RateLimit for CountLanguageModelTokensRateLimit {
     fn capacity() -> usize {
-        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
     }
 
     fn db_name() -> &'static str {
-        "count-tokens-with-language-model"
+        "count-language-model-tokens"
     }
 }
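
Both rate limits read their hourly capacity from an environment variable, falling back to a default when the variable is unset or unparsable, so an operator of a self-hosted collab deployment can tune them without recompiling. A hypothetical override (the variable name comes from the diff above; the value is arbitrary, and in a real deployment it would be set in the service environment rather than in code):

    fn main() {
        // Hypothetical: raise the token-counting limit from the default 600/hour.
        std::env::set_var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR", "1200");
    }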