language_models: Fetch Zed models from the server (#31316)

This PR updates the Zed LLM provider to fetch the available models from
the server instead of hard-coding them in the binary.
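
For context, the provider now issues an authenticated `GET /models` request after sign-in and builds its model list from the response. The sketch below reconstructs the response shape purely from how the new code reads it (`response.models`, `model.supports_thinking`, and so on); the actual `zed_llm_client` 0.8.3 definitions (field names, serde attributes, wire format) may differ.

use serde::Deserialize;

// Approximation of the zed_llm_client types, inferred from usage in this
// diff; not a copy of the crate's real definitions.
#[derive(Deserialize, Clone, PartialEq)]
pub struct LanguageModelId(pub String);

#[derive(Deserialize, Clone, Copy)]
pub enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Deserialize, Clone)]
pub struct LanguageModel {
    pub id: LanguageModelId,
    pub display_name: String,
    pub provider: LanguageModelProvider,
    pub max_token_count: usize,
    pub max_output_tokens: usize, // assumed width; the diff casts it with `as u32`
    pub supports_tools: bool,
    pub supports_images: bool,
    pub supports_thinking: bool,
    pub supports_max_mode: bool,
}

#[derive(Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<LanguageModel>,
    pub default_model: LanguageModelId,
    pub default_fast_model: LanguageModelId,
    pub recommended_models: Vec<LanguageModelId>,
}

Note how `default_model`, `default_fast_model`, and `recommended_models` reference entries in `models` by id; the client resolves them after expanding `-thinking` variants.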

Release Notes:

- Updated the Zed provider to fetch the list of available language
models from the server.
Commit 685933b5c8 (parent 172e0df2d8)
Marshall Bowers, 2025-05-23 19:00:35 -04:00, committed by GitHub
7 changed files with 191 additions and 201 deletions


@@ -524,7 +524,6 @@ jobs:
      APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
      APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
    steps:
@@ -611,7 +610,6 @@ jobs:
    needs: [linux_tests]
    env:
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
    steps:
@@ -669,7 +667,6 @@ jobs:
    needs: [linux_tests]
    env:
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
    steps:
@@ -734,7 +731,6 @@ jobs:
    runs-on: ${{ matrix.system.runner }}
    env:
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
      GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
    steps:
      - name: Checkout repo

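The `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` secret can be dropped from CI because it only existed to bake extra closed-beta models into the binary at compile time. A minimal sketch of the deleted pattern (the real code appears in the provider diff below):

use std::sync::LazyLock;

// `option_env!` reads the variable when the compiler runs, not at runtime,
// which is why every CI build job had to export the secret.
const ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

// Hypothetical stand-in for the provider's `AvailableModel` settings type.
#[derive(serde::Deserialize)]
struct AvailableModel {
    name: String,
}

fn additional_models() -> &'static [AvailableModel] {
    static MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).expect("invalid models JSON"))
            .unwrap_or_default()
    });
    MODELS.as_slice()
}

Since models now arrive over the wire after sign-in, no build-time injection is needed.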

@@ -68,7 +68,6 @@ jobs:
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
    steps:
      - name: Install Node
        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
@@ -104,7 +103,6 @@ jobs:
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
    steps:
      - name: Checkout repo
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@@ -144,7 +142,6 @@ jobs:
      DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
      DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
      ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
-      ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
    steps:
      - name: Checkout repo
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4

Cargo.lock (generated)

@@ -8820,7 +8820,6 @@ dependencies = [
 "credentials_provider",
 "deepseek",
 "editor",
- "feature_flags",
 "fs",
 "futures 0.3.31",
 "google_ai",
@@ -19890,9 +19889,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
-version = "0.8.2"
+version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9be71e2f9b271e1eb8eb3e0d986075e770d1a0a299fb036abc3f1fc13a2fa7eb"
+checksum = "22a8b9575b215536ed8ad254ba07171e4e13bd029eda3b54cca4b184d2768050"
dependencies = [
 "anyhow",
 "serde",


@@ -616,7 +616,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
-zed_llm_client = "0.8.2"
+zed_llm_client = "0.8.3"
zstd = "0.11"

[workspace.dependencies.async-stripe]


@@ -244,25 +244,6 @@ pub trait LanguageModel: Send + Sync {
    /// Returns whether this model supports "max mode";
    fn supports_max_mode(&self) -> bool {
-        if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
-            return false;
-        }
-
-        const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
-            CloudModel::Anthropic(anthropic::Model::ClaudeOpus4),
-            CloudModel::Anthropic(anthropic::Model::ClaudeOpus4Thinking),
-            CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
-            CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
-            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
-            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
-        ];
-
-        for model in MAX_MODE_CAPABLE_MODELS {
-            if self.id().0 == model.id() {
-                return true;
-            }
-        }
-
        false
    }

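With the hard-coded list gone, "max mode" support becomes plain server metadata: the trait default returns false, and the Zed provider's model type overrides it from the flag it received in the `/models` response (visible later in this diff). A simplified sketch of the resulting shape, with types pared down for illustration:

use std::sync::Arc;

// Pared-down stand-ins; the real trait and the zed_llm_client model type
// carry many more methods and fields.
struct ServerModel {
    supports_max_mode: bool,
}

trait LanguageModel {
    /// Defaults to false; providers opt in per model.
    fn supports_max_mode(&self) -> bool {
        false
    }
}

struct CloudLanguageModel {
    model: Arc<ServerModel>,
}

impl LanguageModel for CloudLanguageModel {
    fn supports_max_mode(&self) -> bool {
        // The capability comes from server metadata instead of a
        // compiled-in allowlist of model ids.
        self.model.supports_max_mode
    }
}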

@@ -26,7 +26,6 @@ credentials_provider.workspace = true
copilot.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
-feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }


@@ -1,8 +1,6 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
-use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@@ -11,7 +9,7 @@ use gpui::{
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
-    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
@@ -26,45 +24,30 @@ use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
-use settings::{Settings, SettingsStore};
+use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
-use std::{
-    sync::{Arc, LazyLock},
-    time::Duration,
-};
-use strum::IntoEnumIterator;
+use std::sync::Arc;
+use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
+use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
-    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
-    ZED_VERSION_HEADER_NAME,
+    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

-use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

-const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
-    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
-
-fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
-    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
-        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
-            .map(|json| serde_json::from_str(json).unwrap())
-            .unwrap_or_default()
-    });
-    ADDITIONAL_MODELS.as_slice()
-}
-
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
@@ -137,6 +120,11 @@ pub struct State {
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
+    models: Vec<Arc<zed_llm_client::LanguageModel>>,
+    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
+    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
+    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
+    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
@@ -156,6 +144,72 @@ impl State {
            user_store,
            status,
            accept_terms: None,
+            models: Vec::new(),
+            default_model: None,
+            default_fast_model: None,
+            recommended_models: Vec::new(),
+            _fetch_models_task: cx.spawn(async move |this, cx| {
+                maybe!(async move {
+                    let (client, llm_api_token) = this
+                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+
+                    loop {
+                        let status = this.read_with(cx, |this, _cx| this.status)?;
+                        if matches!(status, client::Status::Connected { .. }) {
+                            break;
+                        }
+
+                        cx.background_executor()
+                            .timer(Duration::from_millis(100))
+                            .await;
+                    }
+
+                    let response = Self::fetch_models(client, llm_api_token).await?;
+                    cx.update(|cx| {
+                        this.update(cx, |this, cx| {
+                            let mut models = Vec::new();
+
+                            for model in response.models {
+                                models.push(Arc::new(model.clone()));
+
+                                // Right now we represent thinking variants of models as separate models on the client,
+                                // so we need to insert variants for any model that supports thinking.
+                                if model.supports_thinking {
+                                    models.push(Arc::new(zed_llm_client::LanguageModel {
+                                        id: zed_llm_client::LanguageModelId(
+                                            format!("{}-thinking", model.id).into(),
+                                        ),
+                                        display_name: format!("{} Thinking", model.display_name),
+                                        ..model
+                                    }));
+                                }
+                            }
+
+                            this.default_model = models
+                                .iter()
+                                .find(|model| model.id == response.default_model)
+                                .cloned();
+                            this.default_fast_model = models
+                                .iter()
+                                .find(|model| model.id == response.default_fast_model)
+                                .cloned();
+                            this.recommended_models = response
+                                .recommended_models
+                                .iter()
+                                .filter_map(|id| models.iter().find(|model| &model.id == id))
+                                .cloned()
+                                .collect();
+                            this.models = models;
+                            cx.notify();
+                        })
+                    })??;
+
+                    anyhow::Ok(())
+                })
+                .await
+                .context("failed to fetch Zed models")
+                .log_err();
+            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
@@ -208,6 +262,37 @@ impl State {
            })
        }));
    }
+
+    async fn fetch_models(
+        client: Arc<Client>,
+        llm_api_token: LlmApiToken,
+    ) -> Result<ListModelsResponse> {
+        let http_client = &client.http_client();
+        let token = llm_api_token.acquire(&client).await?;
+
+        let request = http_client::Request::builder()
+            .method(Method::GET)
+            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
+            .header("Authorization", format!("Bearer {token}"))
+            .body(AsyncBody::empty())?;
+        let mut response = http_client
+            .send(request)
+            .await
+            .context("failed to send list models request")?;
+
+        if response.status().is_success() {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Ok(serde_json::from_str(&body)?);
+        } else {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            anyhow::bail!(
+                "error listing models.\nStatus: {:?}\nBody: {body}",
+                response.status(),
+            );
+        }
+    }
}
impl CloudLanguageModelProvider {
@@ -242,11 +327,11 @@
    fn create_language_model(
        &self,
-        model: CloudModel,
+        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
+            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
@@ -277,121 +362,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        let model = CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4);
-        Some(self.create_language_model(model, llm_api_token))
+        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
-        Some(self.create_language_model(model, llm_api_token))
+        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        [
-            CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
-            CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
-        ]
-        .into_iter()
-        .map(|model| self.create_language_model(model, llm_api_token.clone()))
-        .collect()
+        self.state
+            .read(cx)
+            .recommended_models
+            .iter()
+            .cloned()
+            .map(|model| self.create_language_model(model, llm_api_token.clone()))
+            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = BTreeMap::default();
-
-        if cx.is_staff() {
-            for model in anthropic::Model::iter() {
-                if !matches!(model, anthropic::Model::Custom { .. }) {
-                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
-                }
-            }
-            for model in open_ai::Model::iter() {
-                if !matches!(model, open_ai::Model::Custom { .. }) {
-                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
-                }
-            }
-            for model in google_ai::Model::iter() {
-                if !matches!(model, google_ai::Model::Custom { .. }) {
-                    models.insert(model.id().to_string(), CloudModel::Google(model));
-                }
-            }
-        } else {
-            models.insert(
-                anthropic::Model::Claude3_5Sonnet.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
-            );
-            models.insert(
-                anthropic::Model::Claude3_7Sonnet.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
-            );
-            models.insert(
-                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
-            );
-            models.insert(
-                anthropic::Model::ClaudeSonnet4.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
-            );
-            models.insert(
-                anthropic::Model::ClaudeSonnet4Thinking.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
-            );
-        }
-
-        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
-            zed_cloud_provider_additional_models()
-        } else {
-            &[]
-        };
-
-        // Override with available models from settings
-        for model in AllLanguageModelSettings::get_global(cx)
-            .zed_dot_dev
-            .available_models
-            .iter()
-            .chain(llm_closed_beta_models)
-            .cloned()
-        {
-            let model = match model.provider {
-                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
-                    name: model.name.clone(),
-                    display_name: model.display_name.clone(),
-                    max_tokens: model.max_tokens,
-                    tool_override: model.tool_override.clone(),
-                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
-                        anthropic::AnthropicModelCacheConfiguration {
-                            max_cache_anchors: config.max_cache_anchors,
-                            should_speculate: config.should_speculate,
-                            min_total_token: config.min_total_token,
-                        }
-                    }),
-                    default_temperature: model.default_temperature,
-                    max_output_tokens: model.max_output_tokens,
-                    extra_beta_headers: model.extra_beta_headers.clone(),
-                    mode: model.mode.unwrap_or_default().into(),
-                }),
-                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
-                    name: model.name.clone(),
-                    display_name: model.display_name.clone(),
-                    max_tokens: model.max_tokens,
-                    max_output_tokens: model.max_output_tokens,
-                    max_completion_tokens: model.max_completion_tokens,
-                }),
-                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
-                    name: model.name.clone(),
-                    display_name: model.display_name.clone(),
-                    max_tokens: model.max_tokens,
-                }),
-            };
-            models.insert(model.id().to_string(), model.clone());
-        }
-
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
-        models
-            .into_values()
+        self.state
+            .read(cx)
+            .models
+            .iter()
+            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }
@@ -522,7 +521,7 @@ fn render_accept_terms(
pub struct CloudLanguageModel {
    id: LanguageModelId,
-    model: CloudModel,
+    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
@@ -668,7 +667,7 @@ impl LanguageModel for CloudLanguageModel {
    }

    fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model.display_name().to_string())
+        LanguageModelName::from(self.model.display_name.clone())
    }
    fn provider_id(&self) -> LanguageModelProviderId {
@@ -680,19 +679,11 @@
    }

    fn supports_tools(&self) -> bool {
-        match self.model {
-            CloudModel::Anthropic(_) => true,
-            CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => true,
-        }
+        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
-        match self.model {
-            CloudModel::Anthropic(_) => true,
-            CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => false,
-        }
+        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -703,30 +694,41 @@
        }
    }

+    fn supports_max_mode(&self) -> bool {
+        self.model.supports_max_mode
+    }
+
    fn telemetry_id(&self) -> String {
-        format!("zed.dev/{}", self.model.id())
+        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
-        self.model.tool_input_format()
+        match self.model.provider {
+            zed_llm_client::LanguageModelProvider::Anthropic
+            | zed_llm_client::LanguageModelProvider::OpenAi => {
+                LanguageModelToolSchemaFormat::JsonSchema
+            }
+            zed_llm_client::LanguageModelProvider::Google => {
+                LanguageModelToolSchemaFormat::JsonSchemaSubset
+            }
+        }
    }

    fn max_token_count(&self) -> usize {
-        self.model.max_token_count()
+        self.model.max_token_count
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
-        match &self.model {
-            CloudModel::Anthropic(model) => {
-                model
-                    .cache_configuration()
-                    .map(|cache| LanguageModelCacheConfiguration {
-                        max_cache_anchors: cache.max_cache_anchors,
-                        should_speculate: cache.should_speculate,
-                        min_total_token: cache.min_total_token,
-                    })
-            }
-            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
+        match &self.model.provider {
+            zed_llm_client::LanguageModelProvider::Anthropic => {
+                Some(LanguageModelCacheConfiguration {
+                    min_total_token: 2_048,
+                    should_speculate: true,
+                    max_cache_anchors: 4,
+                })
+            }
+            zed_llm_client::LanguageModelProvider::OpenAi
+            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }
@@ -735,13 +737,19 @@
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
-        match self.model.clone() {
-            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
-            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
-            CloudModel::Google(model) => {
+        match self.model.provider {
+            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
+            zed_llm_client::LanguageModelProvider::OpenAi => {
+                let model = match open_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                };
+                count_open_ai_tokens(request, model, cx)
+            }
+            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
-                let model_id = model.id().to_string();
+                let model_id = self.model.id.to_string();
                let generate_content_request = into_google(request, model_id.clone());
                async move {
                    let http_client = &client.http_client();
@@ -803,14 +811,20 @@
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
-        match &self.model {
-            CloudModel::Anthropic(model) => {
+        match self.model.provider {
+            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
-                    model.request_id().into(),
-                    model.default_temperature(),
-                    model.max_output_tokens(),
-                    model.mode(),
+                    self.model.id.to_string(),
+                    1.0,
+                    self.model.max_output_tokens as u32,
+                    if self.model.id.0.ends_with("-thinking") {
+                        AnthropicModelMode::Thinking {
+                            budget_tokens: Some(4_096),
+                        }
+                    } else {
+                        AnthropicModelMode::Default
+                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
@@ -862,9 +876,13 @@
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
-            CloudModel::OpenAi(model) => {
+            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
-                let request = into_open_ai(request, model, model.max_output_tokens());
+                let model = match open_ai::Model::from_id(&self.model.id.0) {
+                    Ok(model) => model,
+                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
+                };
+                let request = into_open_ai(request, &model, None);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
@@ -899,9 +917,9 @@
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
-            CloudModel::Google(model) => {
+            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
-                let request = into_google(request, model.id().into());
+                let request = into_google(request, self.model.id.to_string());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {