Simplify LLM protocol (#15366)

In this pull request, we change the zed.dev protocol so that we pass the
raw JSON for the specified provider directly to our server. This avoids
the need to define a protobuf message that's a superset of every
provider's request format.
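
As the server code in this diff uses it, the new envelope reduces to a provider tag, a request kind, and an opaque JSON string. Here is a minimal sketch of the corresponding prost-style types; the enum names match the `proto::LanguageModelProvider` and `proto::LanguageModelRequestKind` values used in the handler below, but the exact struct layout and field numbers in `zed.proto` are an assumption, not copied from the repo:

```rust
// Sketch of the prost-generated message types implied by the new handler.
// (Names of the enums appear in the diff; the field layout is assumed.)
pub struct QueryLanguageModel {
    /// Which backend to talk to (proto::LanguageModelProvider:
    /// Anthropic, OpenAi, or Google).
    pub provider: i32,
    /// What to do (proto::LanguageModelRequestKind: Complete or CountTokens).
    pub kind: i32,
    /// The provider-specific request (e.g. an anthropic::Request or
    /// open_ai::Request), serialized to JSON by the client.
    pub request: String,
}

pub struct QueryLanguageModelResponse {
    /// The provider-specific response chunk, serialized back to JSON.
    pub response: String,
}
```

The server only touches the payload via `serde_json::from_str` at the provider boundary and `serde_json::to_string` on each streamed chunk, so new provider-side fields no longer require a protocol change.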

@bennetbo: We also changed the settings for `available_models` under
`zed.dev` to a flat format, because the nesting seemed too confusing.
Can you help us upgrade the local provider configuration to be
consistent with this? We should do whatever is needed when parsing the
settings to keep this simple for users, even if it's a bit more complex
on our end. We want to use versioning to avoid breaking existing users,
but we need to keep making progress (see the parsing sketch after the
example below).

```json
"zed.dev": {
  "available_models": [
    {
      "provider": "anthropic",
        "name": "some-newly-released-model-we-havent-added",
        "max_tokens": 200000
      }
  ]
}
```
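
On the versioning point above: one option is to accept both the old and the new shape at parse time and normalize to the flat format internally, so the complexity stays on our end. A minimal sketch, assuming (hypothetically) that the old nested format keyed each model by provider name; all type names here are illustrative, not the actual Zed settings code:

```rust
use std::collections::BTreeMap;

use serde::Deserialize;

/// The flat entry we want users to write going forward.
#[derive(Clone, Debug, Deserialize)]
pub struct AvailableModel {
    pub provider: String,
    pub name: String,
    pub max_tokens: u64,
}

/// Accepts either the new flat entry or the assumed old nested
/// `{"<provider>": { "name": ..., "max_tokens": ... }}` shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum AvailableModelCompat {
    Flat(AvailableModel),
    Nested(BTreeMap<String, NestedModel>),
}

#[derive(Debug, Deserialize)]
struct NestedModel {
    name: String,
    max_tokens: u64,
}

/// Normalizes whatever the user wrote into the flat format.
fn normalize(entries: Vec<AvailableModelCompat>) -> Vec<AvailableModel> {
    entries
        .into_iter()
        .flat_map(|entry| match entry {
            AvailableModelCompat::Flat(model) => vec![model],
            AvailableModelCompat::Nested(map) => map
                .into_iter()
                .map(|(provider, model)| AvailableModel {
                    provider,
                    name: model.name,
                    max_tokens: model.max_tokens,
                })
                .collect(),
        })
        .collect()
}
```

Because normalization happens once, when settings are deserialized, old files keep parsing while everything downstream only ever sees the flat format.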

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra 2024-07-28 11:07:10 +02:00 committed by GitHub
parent e0fe7f632c
commit d6bdaa8a91
31 changed files with 896 additions and 2154 deletions

@@ -1,138 +0,0 @@
use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
use util::ResultExt as _;

pub fn language_model_request_to_open_ai(
    request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
    Ok(open_ai::Request {
        model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
        messages: request
            .messages
            .into_iter()
            .map(|message: proto::LanguageModelRequestMessage| {
                let role = proto::LanguageModelRole::from_i32(message.role)
                    .ok_or_else(|| anyhow!("invalid role {}", message.role))?;

                let openai_message = match role {
                    proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
                        content: message.content,
                    },
                    proto::LanguageModelRole::LanguageModelAssistant => {
                        open_ai::RequestMessage::Assistant {
                            content: Some(message.content),
                            tool_calls: message
                                .tool_calls
                                .into_iter()
                                .filter_map(|call| {
                                    Some(open_ai::ToolCall {
                                        id: call.id,
                                        content: match call.variant? {
                                            proto::tool_call::Variant::Function(f) => {
                                                open_ai::ToolCallContent::Function {
                                                    function: open_ai::FunctionContent {
                                                        name: f.name,
                                                        arguments: f.arguments,
                                                    },
                                                }
                                            }
                                        },
                                    })
                                })
                                .collect(),
                        }
                    }
                    proto::LanguageModelRole::LanguageModelSystem => {
                        open_ai::RequestMessage::System {
                            content: message.content,
                        }
                    }
                    proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
                        tool_call_id: message
                            .tool_call_id
                            .ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
                        content: message.content,
                    },
                };

                Ok(openai_message)
            })
            .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
        stream: true,
        stop: request.stop,
        temperature: request.temperature,
        tools: request
            .tools
            .into_iter()
            .filter_map(|tool| {
                Some(match tool.variant? {
                    proto::chat_completion_tool::Variant::Function(f) => {
                        open_ai::ToolDefinition::Function {
                            function: open_ai::FunctionDefinition {
                                name: f.name,
                                description: f.description,
                                parameters: if let Some(params) = &f.parameters {
                                    Some(
                                        serde_json::from_str(params)
                                            .context("failed to deserialize tool parameters")
                                            .log_err()?,
                                    )
                                } else {
                                    None
                                },
                            },
                        }
                    }
                })
            })
            .collect(),
        tool_choice: request.tool_choice,
    })
}

pub fn language_model_request_to_google_ai(
    request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
    Ok(google_ai::GenerateContentRequest {
        contents: request
            .messages
            .into_iter()
            .map(language_model_request_message_to_google_ai)
            .collect::<Result<Vec<_>>>()?,
        generation_config: None,
        safety_settings: None,
    })
}

pub fn language_model_request_message_to_google_ai(
    message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
    let role = proto::LanguageModelRole::from_i32(message.role)
        .ok_or_else(|| anyhow!("invalid role {}", message.role))?;

    Ok(google_ai::Content {
        parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
            text: message.content,
        })],
        role: match role {
            proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
            proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
            proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
            proto::LanguageModelRole::LanguageModelTool => {
                Err(anyhow!("we don't handle tool calls with google ai yet"))?
            }
        },
    })
}

pub fn count_tokens_request_to_google_ai(
    request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
    Ok(google_ai::CountTokensRequest {
        contents: request
            .messages
            .into_iter()
            .map(language_model_request_message_to_google_ai)
            .collect::<Result<Vec<_>>>()?,
    })
}

@@ -1,4 +1,3 @@
pub mod ai;
pub mod api;
pub mod auth;
pub mod db;

@@ -46,8 +46,8 @@ use http_client::IsahcHttpClient;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
    proto::{
        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
    },
    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
@@ -618,17 +618,6 @@ impl Server {
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
@@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}

async fn complete_with_language_model(
    mut request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    query: proto::QueryLanguageModel,
    response: StreamingResponse<proto::QueryLanguageModel>,
    session: Session,
    open_ai_api_key: Option<Arc<str>>,
    google_ai_api_key: Option<Arc<str>>,
@@ -4525,287 +4514,95 @@ async fn complete_with_language_model(
        return Err(anyhow!("user not found"))?;
    };
    authorize_access_to_language_models(&session).await?;
    session
        .rate_limiter
        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
        .await?;
    let mut provider_and_model = request.model.split('/');
    let (provider, model) = match (
        provider_and_model.next().unwrap(),
        provider_and_model.next(),
    ) {
        (provider, Some(model)) => (provider, model),
        (model, None) => {
            if model.starts_with("gpt") {
                ("openai", model)
            } else if model.starts_with("gemini") {
                ("google", model)
            } else if model.starts_with("claude") {
                ("anthropic", model)
            } else {
                ("unknown", model)
            }
    match proto::LanguageModelRequestKind::from_i32(query.kind) {
        Some(proto::LanguageModelRequestKind::Complete) => {
            session
                .rate_limiter
                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
                .await?;
        }
    };
    let provider = provider.to_string();
    request.model = model.to_string();
        Some(proto::LanguageModelRequestKind::CountTokens) => {
            session
                .rate_limiter
                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
                .await?;
        }
        None => Err(anyhow!("unknown request kind"))?,
    }
    match provider.as_str() {
        "openai" => {
            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
            complete_with_open_ai(request, response, session, api_key).await?;
        }
        "anthropic" => {
    match proto::LanguageModelProvider::from_i32(query.provider) {
        Some(proto::LanguageModelProvider::Anthropic) => {
            let api_key =
                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
            complete_with_anthropic(request, response, session, api_key).await?;
            let mut chunks = anthropic::stream_completion(
                session.http_client.as_ref(),
                anthropic::ANTHROPIC_API_URL,
                &api_key,
                serde_json::from_str(&query.request)?,
                None,
            )
            .await?;
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                response.send(proto::QueryLanguageModelResponse {
                    response: serde_json::to_string(&chunk)?,
                })?;
            }
        }
        "google" => {
        Some(proto::LanguageModelProvider::OpenAi) => {
            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
            let mut chunks = open_ai::stream_completion(
                session.http_client.as_ref(),
                open_ai::OPEN_AI_API_URL,
                &api_key,
                serde_json::from_str(&query.request)?,
                None,
            )
            .await?;
            while let Some(chunk) = chunks.next().await {
                let chunk = chunk?;
                response.send(proto::QueryLanguageModelResponse {
                    response: serde_json::to_string(&chunk)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Google) => {
            let api_key =
                google_ai_api_key.context("no Google AI API key configured on the server")?;
            complete_with_google_ai(request, response, session, api_key).await?;
        }
        provider => return Err(anyhow!("unknown provider {:?}", provider))?,
    }
    Ok(())
}
async fn complete_with_open_ai(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let mut completion_stream = open_ai::stream_completion(
        session.http_client.as_ref(),
        OPEN_AI_API_URL,
        &api_key,
        crate::ai::language_model_request_to_open_ai(request)?,
        None,
    )
    .await
    .context("open_ai::stream_completion request failed within collab")?;

    while let Some(event) = completion_stream.next().await {
        let event = event?;
        response.send(proto::LanguageModelResponse {
            choices: event
                .choices
                .into_iter()
                .map(|choice| proto::LanguageModelChoiceDelta {
                    index: choice.index,
                    delta: Some(proto::LanguageModelResponseMessage {
                        role: choice.delta.role.map(|role| match role {
                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
                        } as i32),
                        content: choice.delta.content,
                        tool_calls: choice
                            .delta
                            .tool_calls
                            .unwrap_or_default()
                            .into_iter()
                            .map(|delta| proto::ToolCallDelta {
                                index: delta.index as u32,
                                id: delta.id,
                                variant: match delta.function {
                                    Some(function) => {
                                        let name = function.name;
                                        let arguments = function.arguments;
                                        Some(proto::tool_call_delta::Variant::Function(
                                            proto::tool_call_delta::FunctionCallDelta {
                                                name,
                                                arguments,
                                            },
                                        ))
                                    }
                                    None => None,
                                },
                            })
                            .collect(),
                    }),
                    finish_reason: choice.finish_reason,
                })
                .collect(),
        })?;
    }

    Ok(())
}

async fn complete_with_google_ai(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let mut stream = google_ai::stream_generate_content(
        session.http_client.clone(),
        google_ai::API_URL,
        api_key.as_ref(),
        &request.model.clone(),
        crate::ai::language_model_request_to_google_ai(request)?,
    )
    .await
    .context("google_ai::stream_generate_content request failed")?;

    while let Some(event) = stream.next().await {
        let event = event?;
        response.send(proto::LanguageModelResponse {
            choices: event
                .candidates
                .unwrap_or_default()
                .into_iter()
                .map(|candidate| proto::LanguageModelChoiceDelta {
                    index: candidate.index as u32,
                    delta: Some(proto::LanguageModelResponseMessage {
                        role: Some(match candidate.content.role {
                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
                        } as i32),
                        content: Some(
                            candidate
                                .content
                                .parts
                                .into_iter()
                                .filter_map(|part| match part {
                                    google_ai::Part::TextPart(part) => Some(part.text),
                                    google_ai::Part::InlineDataPart(_) => None,
                                })
                                .collect(),
                        ),
                        // Tool calls are not supported for Google
                        tool_calls: Vec::new(),
                    }),
                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
                })
                .collect(),
        })?;
    }

    Ok(())
}
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);
                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();
    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model: request.model,
            messages,
            stream: true,
            system: system_message,
            max_tokens: 4092,
        },
        None,
    )
    .await?;
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
    while let Some(event) = stream.next().await {
        let event = event?;
        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
            match proto::LanguageModelRequestKind::from_i32(query.kind) {
                Some(proto::LanguageModelRequestKind::Complete) => {
                    let mut chunks = google_ai::stream_generate_content(
                        session.http_client.as_ref(),
                        google_ai::API_URL,
                        &api_key,
                        serde_json::from_str(&query.request)?,
                    )
                    .await?;
                    while let Some(chunk) = chunks.next().await {
                        let chunk = chunk?;
                        response.send(proto::QueryLanguageModelResponse {
                            response: serde_json::to_string(&chunk)?,
                        })?;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                Some(proto::LanguageModelRequestKind::CountTokens) => {
                    let tokens_response = google_ai::count_tokens(
                        session.http_client.as_ref(),
                        google_ai::API_URL,
                        &api_key,
                        serde_json::from_str(&query.request)?,
                    )
                    .await?;
                    response.send(proto::QueryLanguageModelResponse {
                        response: serde_json::to_string(&tokens_response)?,
                    })?;
                }
                None => Err(anyhow!("unknown request kind"))?,
            }
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
        None => return Err(anyhow!("unknown provider"))?,
    }
    Ok(())
@@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
    }
}

async fn count_tokens_with_language_model(
    request: proto::CountTokensWithLanguageModel,
    response: Response<proto::CountTokensWithLanguageModel>,
    session: UserSession,
    google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
    authorize_access_to_language_models(&session).await?;

    if !request.model.starts_with("gemini") {
        return Err(anyhow!(
            "counting tokens for model: {:?} is not supported",
            request.model
        ))?;
    }

    session
        .rate_limiter
        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
        .await?;

    let api_key = google_ai_api_key
        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
    let tokens_response = google_ai::count_tokens(
        session.http_client.as_ref(),
        google_ai::API_URL,
        &api_key,
        crate::ai::count_tokens_request_to_google_ai(request)?,
    )
    .await?;
    response.send(proto::CountTokensResponse {
        token_count: tokens_response.total_tokens as u32,
    })?;
    Ok(())
}

struct ComputeEmbeddingsRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {