Use tool calling instead of XML parsing to generate edit operations (#15385)
Release Notes:

- N/A

Co-authored-by: Nathan <nathan@zed.dev>
parent f6012cd86e
commit 6e1f7c6e1d

22 changed files with 1155 additions and 853 deletions
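For context on the title: the change moves edit-operation generation from scraping XML out of model completions to native tool calling, where the model returns structured JSON arguments. A minimal sketch of the difference — every type, field, and function name here is hypothetical, not the PR's real types:

```rust
use serde::Deserialize;

// Hypothetical edit-operation shape; the PR's actual types differ.
#[derive(Debug, Deserialize)]
struct EditOperation {
    path: String,
    old_text: String,
    new_text: String,
}

// Old style (sketch): hunt for <edit>...</edit> tags in free-form model
// output and parse the fields by hand, which breaks on malformed markup.
//
// New style (sketch): declare a tool schema in the request; the model
// replies with a tool call whose JSON arguments deserialize directly
// into the typed operation.
fn edit_from_tool_call(arguments_json: &str) -> serde_json::Result<EditOperation> {
    serde_json::from_str(arguments_json)
}

fn main() {
    let args = r#"{"path":"src/main.rs","old_text":"foo","new_text":"bar"}"#;
    println!("{:?}", edit_from_tool_call(args).unwrap());
}
```

The hunks below are from the collab server's RPC module, where the old catch-all `QueryLanguageModel` message is split into dedicated complete, stream-complete, and count-tokens requests.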
```diff
@@ -10,7 +10,7 @@ use crate::{
         ServerId, UpdatedChannelMessage, User, UserId,
     },
     executor::Executor,
-    AppState, Error, RateLimit, RateLimiter, Result,
+    AppState, Config, Error, RateLimit, RateLimiter, Result,
 };
 use anyhow::{anyhow, bail, Context as _};
 use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        complete_with_language_model(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_streaming_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
-                    complete_with_language_model(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                        app_state.config.google_ai_api_key.clone(),
-                        app_state.config.anthropic_api_key.clone(),
-                    )
+                    let app_state = app_state.clone();
+                    async move {
+                        stream_complete_with_language_model(
+                            request,
+                            response,
+                            session,
+                            &app_state.config,
+                        )
+                        .await
+                    }
                 }
             })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        count_language_model_tokens(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_request_handler({
```
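Note the double clone in each handler registered above: the outer `clone` moves a handle into the `move` closure, which must stay callable for every incoming request, while the inner `clone` moves a fresh handle into the `async move` future produced per call. A self-contained sketch of the pattern, with a hypothetical registry function standing in for the server's `add_request_handler`:

```rust
use std::future::Future;
use std::sync::Arc;

struct AppState;

// Hypothetical stand-in for the server's handler registration.
fn add_request_handler<F, Fut>(handler: F)
where
    F: Fn(u32) -> Fut + 'static,
    Fut: Future<Output = ()> + 'static,
{
    let _ = handler; // registration machinery elided in this sketch
}

fn main() {
    let app_state = Arc::new(AppState);
    add_request_handler({
        let app_state = app_state.clone(); // owned by the closure itself
        move |request_id| {
            let app_state = app_state.clone(); // owned by this call's future
            async move {
                let _ = (&*app_state, request_id); // handle one request
            }
        }
    });
}
```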
```diff
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    query: proto::QueryLanguageModel,
-    response: StreamingResponse<proto::QueryLanguageModel>,
+    request: proto::CompleteWithLanguageModel,
+    response: Response<proto::CompleteWithLanguageModel>,
     session: Session,
-    open_ai_api_key: Option<Arc<str>>,
-    google_ai_api_key: Option<Arc<str>>,
-    anthropic_api_key: Option<Arc<str>>,
+    config: &Config,
 ) -> Result<()> {
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
     };
     authorize_access_to_language_models(&session).await?;
 
-    match proto::LanguageModelRequestKind::from_i32(query.kind) {
-        Some(proto::LanguageModelRequestKind::Complete) => {
-            session
-                .rate_limiter
-                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        Some(proto::LanguageModelRequestKind::CountTokens) => {
-            session
-                .rate_limiter
-                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        None => Err(anyhow!("unknown request kind"))?,
-    }
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
 
-    match proto::LanguageModelProvider::from_i32(query.provider) {
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
         Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key =
-                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            anthropic::complete(
+                session.http_client.as_ref(),
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CompleteWithLanguageModelResponse {
+        completion: serde_json::to_string(&result)?,
+    })?;
+
+    Ok(())
+}
+
+async fn stream_complete_with_language_model(
+    request: proto::StreamCompleteWithLanguageModel,
+    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Anthropic) => {
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
             let mut chunks = anthropic::stream_completion(
                 session.http_client.as_ref(),
                 anthropic::ANTHROPIC_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = chunks.next().await {
+                let chunk = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&chunk)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
-            let mut chunks = open_ai::stream_completion(
+            let api_key = config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let mut events = open_ai::stream_completion(
                 session.http_client.as_ref(),
                 open_ai::OPEN_AI_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::Google) => {
-            let api_key =
-                google_ai_api_key.context("no Google AI API key configured on the server")?;
-
-            match proto::LanguageModelRequestKind::from_i32(query.kind) {
-                Some(proto::LanguageModelRequestKind::Complete) => {
-                    let mut chunks = google_ai::stream_generate_content(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    while let Some(chunk) = chunks.next().await {
-                        let chunk = chunk?;
-                        response.send(proto::QueryLanguageModelResponse {
-                            response: serde_json::to_string(&chunk)?,
-                        })?;
-                    }
-                }
-                Some(proto::LanguageModelRequestKind::CountTokens) => {
-                    let tokens_response = google_ai::count_tokens(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    response.send(proto::QueryLanguageModelResponse {
-                        response: serde_json::to_string(&tokens_response)?,
-                    })?;
-                }
-                None => Err(anyhow!("unknown request kind"))?,
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let mut events = google_ai::stream_generate_content(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?;
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
+                })?;
             }
         }
         None => return Err(anyhow!("unknown provider"))?,
```
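One detail worth noting in the loops above: each provider event is re-serialized with `serde_json::to_string` and carried inside a proto message as a plain string field (`event`), which keeps the RPC schema provider-agnostic — new providers or event variants change only the Rust types, not the protobuf definitions. A sketch of that envelope idea, with a hypothetical event type and a plain struct standing in for the proto message:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical provider event; the real ones come from the anthropic,
// open_ai, and google_ai crates.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionEvent {
    delta: String,
}

// Stand-in for proto::StreamCompleteWithLanguageModelResponse.
struct StreamResponse {
    event: String, // JSON payload, opaque to the protobuf schema
}

fn envelope(event: &CompletionEvent) -> serde_json::Result<StreamResponse> {
    Ok(StreamResponse {
        event: serde_json::to_string(event)?,
    })
}

fn main() {
    let wrapped = envelope(&CompletionEvent { delta: "hi".into() }).unwrap();
    // The client deserializes the payload back into its own event type.
    let roundtrip: CompletionEvent = serde_json::from_str(&wrapped.event).unwrap();
    println!("{} / {:?}", wrapped.event, roundtrip);
}
```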
```diff
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
     Ok(())
 }
 
-struct CountTokensWithLanguageModelRateLimit;
+async fn count_language_model_tokens(
+    request: proto::CountLanguageModelTokens,
+    response: Response<proto::CountLanguageModelTokens>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
 
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+    session
+        .rate_limiter
+        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Google) => {
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            google_ai::count_tokens(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CountLanguageModelTokensResponse {
+        token_count: result.total_tokens as u32,
+    })?;
+
+    Ok(())
+}
+
+struct CountLanguageModelTokensRateLimit;
+
+impl RateLimit for CountLanguageModelTokensRateLimit {
     fn capacity() -> usize {
-        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
     }
 
     fn db_name() -> &'static str {
-        "count-tokens-with-language-model"
+        "count-language-model-tokens"
     }
 }
```
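The renamed rate limiter also illustrates the pattern collab uses for tunable limits: capacity comes from an environment variable with a hard-coded fallback, and `db_name` keys the persisted counters. A minimal self-contained sketch against an assumed two-method slice of the `RateLimit` trait — the real trait carries more methods and is enforced by the server's `RateLimiter`:

```rust
// Assumed two-method slice of collab's RateLimit trait; the real trait
// has additional methods and wiring.
trait RateLimit {
    fn capacity() -> usize;
    fn db_name() -> &'static str;
}

struct CountLanguageModelTokensRateLimit;

impl RateLimit for CountLanguageModelTokensRateLimit {
    fn capacity() -> usize {
        // Env override with a default, as in the diff above.
        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(600)
    }

    fn db_name() -> &'static str {
        "count-language-model-tokens"
    }
}

fn main() {
    println!(
        "{}: {} per hour",
        CountLanguageModelTokensRateLimit::db_name(),
        CountLanguageModelTokensRateLimit::capacity()
    );
}
```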