Simplify LLM protocol (#15366)
In this pull request, we change the zed.dev protocol so that we pass the raw JSON for the specified provider directly to our server. This avoids the need to define a protobuf message that's a superset of all these formats. @bennetbo: We also changed the settings for available_models under zed.dev to be a flat format, because the nesting seemed too confusing. Can you help us upgrade the local provider configuration to be consistent with this? We do whatever we need to do when parsing the settings to make this simple for users, even if it's a bit more complex on our end. We want to use versioning to avoid breaking existing users, but need to keep making progress. ```json "zed.dev": { "available_models": [ { "provider": "anthropic", "name": "some-newly-released-model-we-havent-added", "max_tokens": 200000 } ] } ``` Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
e0fe7f632c
commit
d6bdaa8a91
31 changed files with 896 additions and 2154 deletions
|
@ -7,8 +7,10 @@ use crate::{
|
|||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::BTreeMap;
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
|
@ -16,14 +18,29 @@ use ui::prelude::*;
|
|||
|
||||
use crate::LanguageModelProvider;
|
||||
|
||||
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
|
||||
pub const PROVIDER_ID: &str = "zed.dev";
|
||||
pub const PROVIDER_NAME: &str = "zed.dev";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct ZedDotDevSettings {
|
||||
pub available_models: Vec<CloudModel>,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AvailableProvider {
|
||||
Anthropic,
|
||||
OpenAi,
|
||||
Google,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
provider: AvailableProvider,
|
||||
name: String,
|
||||
max_tokens: usize,
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModelProvider {
|
||||
|
@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from CloudModel::iter()
|
||||
for model in CloudModel::iter() {
|
||||
if !matches!(model, CloudModel::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
for model in anthropic::Model::iter() {
|
||||
if !matches!(model, anthropic::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
|
||||
}
|
||||
}
|
||||
for model in open_ai::Model::iter() {
|
||||
if !matches!(model, open_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::OpenAi(model));
|
||||
}
|
||||
}
|
||||
for model in google_ai::Model::iter() {
|
||||
if !matches!(model, google_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::Google(model));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
.zed_dot_dev
|
||||
.available_models
|
||||
{
|
||||
let model = match model.provider {
|
||||
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
}),
|
||||
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
}),
|
||||
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
}),
|
||||
};
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
|
@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match &self.model {
|
||||
CloudModel::Gpt3Point5Turbo => {
|
||||
count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
|
||||
}
|
||||
CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
|
||||
CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
|
||||
CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
|
||||
CloudModel::Gpt4OmniMini => {
|
||||
count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
|
||||
}
|
||||
CloudModel::Claude3_5Sonnet
|
||||
| CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
count_anthropic_tokens(request, cx)
|
||||
}
|
||||
_ => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
match self.model.clone() {
|
||||
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
let request = google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
};
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client.request(proto::QueryLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Google as i32,
|
||||
kind: proto::LanguageModelRequestKind::CountTokens as i32,
|
||||
request,
|
||||
});
|
||||
let response = response.await?;
|
||||
let response =
|
||||
serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
|
||||
Ok(response.total_tokens)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel {
|
|||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
mut request: LanguageModelRequest,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match &self.model {
|
||||
CloudModel::Claude3Opus
|
||||
| CloudModel::Claude3Sonnet
|
||||
| CloudModel::Claude3Haiku
|
||||
| CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
|
||||
CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
|
||||
preprocess_anthropic_request(&mut request)
|
||||
CloudModel::Anthropic(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_anthropic(model.id().into());
|
||||
async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client.request_stream(proto::QueryLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
||||
request,
|
||||
});
|
||||
let chunks = response.await?;
|
||||
Ok(anthropic::extract_text_from_events(
|
||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_open_ai(model.id().into());
|
||||
async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client.request_stream(proto::QueryLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
||||
request,
|
||||
});
|
||||
let chunks = response.await?;
|
||||
Ok(open_ai::extract_text_from_events(
|
||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = request.into_google(model.id().into());
|
||||
async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client.request_stream(proto::QueryLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Google as i32,
|
||||
kind: proto::LanguageModelRequestKind::Complete as i32,
|
||||
request,
|
||||
});
|
||||
let chunks = response.await?;
|
||||
Ok(google_ai::extract_text_from_events(
|
||||
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
|
||||
)
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: self.id.0.to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue