language_models: Fetch Zed models from the server (#31316)
This PR updates the Zed LLM provider to fetch the available models from the server instead of hard-coding them in the binary. Release Notes: - Updated the Zed provider to fetch the list of available language models from the server.
This commit is contained in:
parent
172e0df2d8
commit
685933b5c8
7 changed files with 191 additions and 201 deletions
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
|
@ -524,7 +524,6 @@ jobs:
|
|||
APPLE_NOTARIZATION_KEY_ID: ${{ secrets.APPLE_NOTARIZATION_KEY_ID }}
|
||||
APPLE_NOTARIZATION_ISSUER_ID: ${{ secrets.APPLE_NOTARIZATION_ISSUER_ID }}
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
steps:
|
||||
|
@ -611,7 +610,6 @@ jobs:
|
|||
needs: [linux_tests]
|
||||
env:
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
steps:
|
||||
|
@ -669,7 +667,6 @@ jobs:
|
|||
needs: [linux_tests]
|
||||
env:
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
steps:
|
||||
|
@ -734,7 +731,6 @@ jobs:
|
|||
runs-on: ${{ matrix.system.runner }}
|
||||
env:
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
|
|
3
.github/workflows/release_nightly.yml
vendored
3
.github/workflows/release_nightly.yml
vendored
|
@ -68,7 +68,6 @@ jobs:
|
|||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Install Node
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||
|
@ -104,7 +103,6 @@ jobs:
|
|||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||
|
@ -144,7 +142,6 @@ jobs:
|
|||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||
|
|
5
Cargo.lock
generated
5
Cargo.lock
generated
|
@ -8820,7 +8820,6 @@ dependencies = [
|
|||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
|
@ -19890,9 +19889,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.8.2"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9be71e2f9b271e1eb8eb3e0d986075e770d1a0a299fb036abc3f1fc13a2fa7eb"
|
||||
checksum = "22a8b9575b215536ed8ad254ba07171e4e13bd029eda3b54cca4b184d2768050"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
|
|
|
@ -616,7 +616,7 @@ wasmtime-wasi = "29"
|
|||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.8.2"
|
||||
zed_llm_client = "0.8.3"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
|
|
|
@ -244,25 +244,6 @@ pub trait LanguageModel: Send + Sync {
|
|||
|
||||
/// Returns whether this model supports "max mode";
|
||||
fn supports_max_mode(&self) -> bool {
|
||||
if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
|
||||
return false;
|
||||
}
|
||||
|
||||
const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeOpus4),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeOpus4Thinking),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
|
||||
CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
|
||||
CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
|
||||
];
|
||||
|
||||
for model in MAX_MODE_CAPABLE_MODELS {
|
||||
if self.id().0 == model.id() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@ credentials_provider.workspace = true
|
|||
copilot.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
google_ai = { workspace = true, features = ["schemars"] }
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag};
|
||||
use futures::{
|
||||
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
|
||||
};
|
||||
|
@ -11,7 +9,7 @@ use gpui::{
|
|||
};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||
use language_model::{
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
|
||||
|
@ -26,45 +24,30 @@ use proto::Plan;
|
|||
use release_channel::AppVersion;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::SettingsStore;
|
||||
use smol::Timer;
|
||||
use smol::io::{AsyncReadExt, BufReader};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::{
|
||||
sync::{Arc, LazyLock},
|
||||
time::Duration,
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
use util::{ResultExt as _, maybe};
|
||||
use zed_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||
CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
|
||||
SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
|
||||
ZED_VERSION_HEADER_NAME,
|
||||
ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
|
||||
use crate::provider::google::{GoogleEventMapper, into_google};
|
||||
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
|
||||
|
||||
pub const PROVIDER_NAME: &str = "Zed";
|
||||
|
||||
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
|
||||
option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
|
||||
|
||||
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
|
||||
static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
|
||||
.map(|json| serde_json::from_str(json).unwrap())
|
||||
.unwrap_or_default()
|
||||
});
|
||||
ADDITIONAL_MODELS.as_slice()
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct ZedDotDevSettings {
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
|
@ -137,6 +120,11 @@ pub struct State {
|
|||
user_store: Entity<UserStore>,
|
||||
status: client::Status,
|
||||
accept_terms: Option<Task<Result<()>>>,
|
||||
models: Vec<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
|
||||
_fetch_models_task: Task<()>,
|
||||
_settings_subscription: Subscription,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
@ -156,6 +144,72 @@ impl State {
|
|||
user_store,
|
||||
status,
|
||||
accept_terms: None,
|
||||
models: Vec::new(),
|
||||
default_model: None,
|
||||
default_fast_model: None,
|
||||
recommended_models: Vec::new(),
|
||||
_fetch_models_task: cx.spawn(async move |this, cx| {
|
||||
maybe!(async move {
|
||||
let (client, llm_api_token) = this
|
||||
.read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
|
||||
|
||||
loop {
|
||||
let status = this.read_with(cx, |this, _cx| this.status)?;
|
||||
if matches!(status, client::Status::Connected { .. }) {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(100))
|
||||
.await;
|
||||
}
|
||||
|
||||
let response = Self::fetch_models(client, llm_api_token).await?;
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
let mut models = Vec::new();
|
||||
|
||||
for model in response.models {
|
||||
models.push(Arc::new(model.clone()));
|
||||
|
||||
// Right now we represent thinking variants of models as separate models on the client,
|
||||
// so we need to insert variants for any model that supports thinking.
|
||||
if model.supports_thinking {
|
||||
models.push(Arc::new(zed_llm_client::LanguageModel {
|
||||
id: zed_llm_client::LanguageModelId(
|
||||
format!("{}-thinking", model.id).into(),
|
||||
),
|
||||
display_name: format!("{} Thinking", model.display_name),
|
||||
..model
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
this.default_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_model)
|
||||
.cloned();
|
||||
this.default_fast_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_fast_model)
|
||||
.cloned();
|
||||
this.recommended_models = response
|
||||
.recommended_models
|
||||
.iter()
|
||||
.filter_map(|id| models.iter().find(|model| &model.id == id))
|
||||
.cloned()
|
||||
.collect();
|
||||
this.models = models;
|
||||
cx.notify();
|
||||
})
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.await
|
||||
.context("failed to fetch Zed models")
|
||||
.log_err();
|
||||
}),
|
||||
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
|
@ -208,6 +262,37 @@ impl State {
|
|||
})
|
||||
}));
|
||||
}
|
||||
|
||||
async fn fetch_models(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
) -> Result<ListModelsResponse> {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(AsyncBody::empty())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send list models request")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!(
|
||||
"error listing models.\nStatus: {:?}\nBody: {body}",
|
||||
response.status(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CloudLanguageModelProvider {
|
||||
|
@ -242,11 +327,11 @@ impl CloudLanguageModelProvider {
|
|||
|
||||
fn create_language_model(
|
||||
&self,
|
||||
model: CloudModel,
|
||||
model: Arc<zed_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(CloudLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
id: LanguageModelId(SharedString::from(model.id.0.clone())),
|
||||
model,
|
||||
llm_api_token: llm_api_token.clone(),
|
||||
client: self.client.clone(),
|
||||
|
@ -277,121 +362,35 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
|||
}
|
||||
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
let default_model = self.state.read(cx).default_model.clone()?;
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
let model = CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4);
|
||||
Some(self.create_language_model(model, llm_api_token))
|
||||
Some(self.create_language_model(default_model, llm_api_token))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
|
||||
Some(self.create_language_model(model, llm_api_token))
|
||||
Some(self.create_language_model(default_fast_model, llm_api_token))
|
||||
}
|
||||
|
||||
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
[
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|model| self.create_language_model(model, llm_api_token.clone()))
|
||||
.collect()
|
||||
self.state
|
||||
.read(cx)
|
||||
.recommended_models
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|model| self.create_language_model(model, llm_api_token.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
if cx.is_staff() {
|
||||
for model in anthropic::Model::iter() {
|
||||
if !matches!(model, anthropic::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::Anthropic(model));
|
||||
}
|
||||
}
|
||||
for model in open_ai::Model::iter() {
|
||||
if !matches!(model, open_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::OpenAi(model));
|
||||
}
|
||||
}
|
||||
for model in google_ai::Model::iter() {
|
||||
if !matches!(model, google_ai::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), CloudModel::Google(model));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models.insert(
|
||||
anthropic::Model::Claude3_5Sonnet.id().to_string(),
|
||||
CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
|
||||
);
|
||||
models.insert(
|
||||
anthropic::Model::Claude3_7Sonnet.id().to_string(),
|
||||
CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
|
||||
);
|
||||
models.insert(
|
||||
anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
|
||||
CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
|
||||
);
|
||||
models.insert(
|
||||
anthropic::Model::ClaudeSonnet4.id().to_string(),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4),
|
||||
);
|
||||
models.insert(
|
||||
anthropic::Model::ClaudeSonnet4Thinking.id().to_string(),
|
||||
CloudModel::Anthropic(anthropic::Model::ClaudeSonnet4Thinking),
|
||||
);
|
||||
}
|
||||
|
||||
let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
|
||||
zed_cloud_provider_additional_models()
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
|
||||
// Override with available models from settings
|
||||
for model in AllLanguageModelSettings::get_global(cx)
|
||||
.zed_dot_dev
|
||||
.available_models
|
||||
.iter()
|
||||
.chain(llm_closed_beta_models)
|
||||
.cloned()
|
||||
{
|
||||
let model = match model.provider {
|
||||
AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
tool_override: model.tool_override.clone(),
|
||||
cache_configuration: model.cache_configuration.as_ref().map(|config| {
|
||||
anthropic::AnthropicModelCacheConfiguration {
|
||||
max_cache_anchors: config.max_cache_anchors,
|
||||
should_speculate: config.should_speculate,
|
||||
min_total_token: config.min_total_token,
|
||||
}
|
||||
}),
|
||||
default_temperature: model.default_temperature,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
extra_beta_headers: model.extra_beta_headers.clone(),
|
||||
mode: model.mode.unwrap_or_default().into(),
|
||||
}),
|
||||
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
}),
|
||||
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
}),
|
||||
};
|
||||
models.insert(model.id().to_string(), model.clone());
|
||||
}
|
||||
|
||||
let llm_api_token = self.state.read(cx).llm_api_token.clone();
|
||||
models
|
||||
.into_values()
|
||||
self.state
|
||||
.read(cx)
|
||||
.models
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|model| self.create_language_model(model, llm_api_token.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
@ -522,7 +521,7 @@ fn render_accept_terms(
|
|||
|
||||
pub struct CloudLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: CloudModel,
|
||||
model: Arc<zed_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
client: Arc<Client>,
|
||||
request_limiter: RateLimiter,
|
||||
|
@ -668,7 +667,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
LanguageModelName::from(self.model.display_name.clone())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
|
@ -680,19 +679,11 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self.model {
|
||||
CloudModel::Anthropic(_) => true,
|
||||
CloudModel::Google(_) => true,
|
||||
CloudModel::OpenAi(_) => true,
|
||||
}
|
||||
self.model.supports_tools
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
match self.model {
|
||||
CloudModel::Anthropic(_) => true,
|
||||
CloudModel::Google(_) => true,
|
||||
CloudModel::OpenAi(_) => false,
|
||||
}
|
||||
self.model.supports_images
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
|
@ -703,30 +694,41 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
}
|
||||
|
||||
fn supports_max_mode(&self) -> bool {
|
||||
self.model.supports_max_mode
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("zed.dev/{}", self.model.id())
|
||||
format!("zed.dev/{}", self.model.id)
|
||||
}
|
||||
|
||||
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
||||
self.model.tool_input_format()
|
||||
match self.model.provider {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic
|
||||
| zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
LanguageModelToolSchemaFormat::JsonSchema
|
||||
}
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> usize {
|
||||
self.model.max_token_count()
|
||||
self.model.max_token_count
|
||||
}
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
model
|
||||
.cache_configuration()
|
||||
.map(|cache| LanguageModelCacheConfiguration {
|
||||
max_cache_anchors: cache.max_cache_anchors,
|
||||
should_speculate: cache.should_speculate,
|
||||
min_total_token: cache.min_total_token,
|
||||
})
|
||||
match &self.model.provider {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => {
|
||||
Some(LanguageModelCacheConfiguration {
|
||||
min_total_token: 2_048,
|
||||
should_speculate: true,
|
||||
max_cache_anchors: 4,
|
||||
})
|
||||
}
|
||||
CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
|
||||
zed_llm_client::LanguageModelProvider::OpenAi
|
||||
| zed_llm_client::LanguageModelProvider::Google => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -735,13 +737,19 @@ impl LanguageModel for CloudLanguageModel {
|
|||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match self.model.clone() {
|
||||
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||
CloudModel::Google(model) => {
|
||||
match self.model.provider {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
|
||||
zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
||||
Ok(model) => model,
|
||||
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
|
||||
};
|
||||
count_open_ai_tokens(request, model, cx)
|
||||
}
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let model_id = model.id().to_string();
|
||||
let model_id = self.model.id.to_string();
|
||||
let generate_content_request = into_google(request, model_id.clone());
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
|
@ -803,14 +811,20 @@ impl LanguageModel for CloudLanguageModel {
|
|||
let prompt_id = request.prompt_id.clone();
|
||||
let mode = request.mode;
|
||||
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
match self.model.provider {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => {
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
model.request_id().into(),
|
||||
model.default_temperature(),
|
||||
model.max_output_tokens(),
|
||||
model.mode(),
|
||||
self.model.id.to_string(),
|
||||
1.0,
|
||||
self.model.max_output_tokens as u32,
|
||||
if self.model.id.0.ends_with("-thinking") {
|
||||
AnthropicModelMode::Thinking {
|
||||
budget_tokens: Some(4_096),
|
||||
}
|
||||
} else {
|
||||
AnthropicModelMode::Default
|
||||
},
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
|
@ -862,9 +876,13 @@ impl LanguageModel for CloudLanguageModel {
|
|||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
let client = self.client.clone();
|
||||
let request = into_open_ai(request, model, model.max_output_tokens());
|
||||
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
||||
Ok(model) => model,
|
||||
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
|
||||
};
|
||||
let request = into_open_ai(request, &model, None);
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let PerformLlmCompletionResponse {
|
||||
|
@ -899,9 +917,9 @@ impl LanguageModel for CloudLanguageModel {
|
|||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
CloudModel::Google(model) => {
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let request = into_google(request, model.id().into());
|
||||
let request = into_google(request, self.model.id.to_string());
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let PerformLlmCompletionResponse {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue