Merge branch 'main' into push-trzsxukkukpr

Peter Tripp 2025-07-25 09:41:34 -04:00
commit 2dde3fd58c
553 changed files with 37661 additions and 11007 deletions


@ -12,6 +12,7 @@ workspace = true
path = "src/language_models.rs"
[dependencies]
ai_onboarding.workspace = true
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
aws-config = { workspace = true, features = ["behavior-version-latest"] }
@ -25,11 +26,10 @@ client.workspace = true
collections.workspace = true
component.workspace = true
credentials_provider.workspace = true
convert_case.workspace = true
copilot.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
proto.workspace = true
release_channel.workspace = true


@ -1,8 +1,10 @@
use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
@ -18,16 +20,81 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
crate::settings::init(cx);
crate::settings::init_settings(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client, cx);
register_language_model_providers(registry, user_store, client.clone(), cx);
});
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()
.cloned()
.collect::<HashSet<_>>();
registry.update(cx, |registry, cx| {
register_openai_compatible_providers(
registry,
&HashSet::default(),
&openai_compatible_providers,
client.clone(),
cx,
);
});
cx.observe_global::<SettingsStore>(move |cx| {
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()
.cloned()
.collect::<HashSet<_>>();
if openai_compatible_providers_new != openai_compatible_providers {
registry.update(cx, |registry, cx| {
register_openai_compatible_providers(
registry,
&openai_compatible_providers,
&openai_compatible_providers_new,
client.clone(),
cx,
);
});
openai_compatible_providers = openai_compatible_providers_new;
}
})
.detach();
}
fn register_openai_compatible_providers(
registry: &mut LanguageModelRegistry,
old: &HashSet<Arc<str>>,
new: &HashSet<Arc<str>>,
client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>,
) {
for provider_id in old {
if !new.contains(provider_id) {
registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
}
}
for provider_id in new {
if !old.contains(provider_id) {
registry.register_provider(
OpenAiCompatibleLanguageModelProvider::new(
provider_id.clone(),
client.http_client(),
cx,
),
cx,
);
}
}
}
fn register_language_model_providers(
@ -81,5 +148,6 @@ fn register_language_model_providers(
VercelLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
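For reference, the `openai_compatible` map read above comes from user settings. Below is a minimal sketch of one entry, matching the `OpenAiCompatibleSettings` and `AvailableModel` fields added later in this commit; the provider name, URL, and model values are hypothetical, and the `language_models` nesting in settings.json is an assumption:

```json
{
  "language_models": {
    "openai_compatible": {
      "example-provider": {
        "api_url": "https://api.example.com/v1",
        "available_models": [
          {
            "name": "example-model",
            "display_name": "Example Model",
            "max_tokens": 32768
          }
        ]
      }
    }
  }
}
```

Each key of this map becomes a `LanguageModelProviderId`, which is why adding or removing an entry drives the register/unregister diffing above.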


@ -8,5 +8,7 @@ pub mod lmstudio;
pub mod mistral;
pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
pub mod vercel;
pub mod x_ai;


@ -663,7 +663,9 @@ pub fn into_anthropic(
} else {
Some(anthropic::StringOrContents::String(system_message))
},
thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
thinking: if request.thinking_allowed
&& let AnthropicModelMode::Thinking { budget_tokens } = mode
{
Some(anthropic::Thinking::Enabled { budget_tokens })
} else {
None
@ -1108,6 +1110,7 @@ mod tests {
temperature: None,
tools: vec![],
tool_choice: None,
thinking_allowed: true,
};
let anthropic_request = into_anthropic(


@ -243,7 +243,7 @@ impl State {
pub struct BedrockLanguageModelProvider {
http_client: AwsHttpClient,
handler: tokio::runtime::Handle,
handle: tokio::runtime::Handle,
state: gpui::Entity<State>,
}
@ -258,13 +258,9 @@ impl BedrockLanguageModelProvider {
}),
});
let tokio_handle = Tokio::handle(cx);
let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
Self {
http_client: coerced_client,
handler: tokio_handle.clone(),
http_client: AwsHttpClient::new(http_client.clone()),
handle: Tokio::handle(cx),
state,
}
}
@ -274,7 +270,7 @@ impl BedrockLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()),
model,
http_client: self.http_client.clone(),
handler: self.handler.clone(),
handle: self.handle.clone(),
state: self.state.clone(),
client: OnceCell::new(),
request_limiter: RateLimiter::new(4),
@ -375,7 +371,7 @@ struct BedrockModel {
id: LanguageModelId,
model: Model,
http_client: AwsHttpClient,
handler: tokio::runtime::Handle,
handle: tokio::runtime::Handle,
client: OnceCell<BedrockClient>,
state: gpui::Entity<State>,
request_limiter: RateLimiter,
@ -447,7 +443,7 @@ impl BedrockModel {
}
}
let config = self.handler.block_on(config_builder.load());
let config = self.handle.block_on(config_builder.load());
anyhow::Ok(BedrockClient::new(&config))
})
.context("initializing Bedrock client")?;
@ -799,7 +795,9 @@ pub fn into_bedrock(
max_tokens: max_output_tokens,
system: Some(system_message),
tools: Some(tool_config),
thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode {
thinking: if request.thinking_allowed
&& let BedrockModelMode::Thinking { budget_tokens } = mode
{
Some(bedrock::Thinking::Enabled { budget_tokens })
} else {
None


@ -1,8 +1,8 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@ -137,7 +137,6 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
Self {
client: client.clone(),
@ -165,47 +164,10 @@ impl State {
.await;
}
let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
cx.update(|cx| {
this.update(cx, |this, cx| {
let mut models = Vec::new();
for model in response.models {
models.push(Arc::new(model.clone()));
// Right now we represent thinking variants of models as separate models on the client,
// so we need to insert variants for any model that supports thinking.
if model.supports_thinking {
models.push(Arc::new(zed_llm_client::LanguageModel {
id: zed_llm_client::LanguageModelId(
format!("{}-thinking", model.id).into(),
),
display_name: format!("{} Thinking", model.display_name),
..model
}));
}
}
this.default_model = models
.iter()
.find(|model| model.id == response.default_model)
.cloned();
this.default_fast_model = models
.iter()
.find(|model| model.id == response.default_fast_model)
.cloned();
this.recommended_models = response
.recommended_models
.iter()
.filter_map(|id| models.iter().find(|model| &model.id == id))
.cloned()
.collect();
this.models = models;
cx.notify();
})
})??;
anyhow::Ok(())
let response = Self::fetch_models(client, llm_api_token).await?;
this.update(cx, |this, cx| {
this.update_models(response, cx);
})
})
.await
.context("failed to fetch Zed models")
@ -216,12 +178,15 @@ impl State {
}),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _listener, _event, cx| {
move |this, _listener, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(async move |_this, _cx| {
cx.spawn(async move |this, cx| {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
let response = Self::fetch_models(client, llm_api_token).await?;
this.update(cx, |this, cx| {
this.update_models(response, cx);
})
})
.detach_and_log_err(cx);
},
@ -264,21 +229,51 @@ impl State {
}));
}
fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
let mut models = Vec::new();
for model in response.models {
models.push(Arc::new(model.clone()));
// Right now we represent thinking variants of models as separate models on the client,
// so we need to insert variants for any model that supports thinking.
if model.supports_thinking {
models.push(Arc::new(zed_llm_client::LanguageModel {
id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
display_name: format!("{} Thinking", model.display_name),
..model
}));
}
}
self.default_model = models
.iter()
.find(|model| model.id == response.default_model)
.cloned();
self.default_fast_model = models
.iter()
.find(|model| model.id == response.default_fast_model)
.cloned();
self.recommended_models = response
.recommended_models
.iter()
.filter_map(|id| models.iter().find(|model| &model.id == id))
.cloned()
.collect();
self.models = models;
cx.notify();
}
async fn fetch_models(
client: Arc<Client>,
llm_api_token: LlmApiToken,
use_cloud: bool,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
let request = http_client::Request::builder()
.method(Method::GET)
.uri(
http_client
.build_zed_llm_url("/models", &[], use_cloud)?
.as_ref(),
)
.uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
.header("Authorization", format!("Bearer {token}"))
.body(AsyncBody::empty())?;
let mut response = http_client
@ -506,7 +501,7 @@ fn render_accept_terms(
)
.child({
match view_kind {
LanguageModelProviderTosView::PromptEditorPopup => {
LanguageModelProviderTosView::TextThreadPopup => {
button_container.w_full().justify_end()
}
LanguageModelProviderTosView::Configuration => {
@ -542,7 +537,6 @@ impl CloudLanguageModel {
llm_api_token: LlmApiToken,
app_version: Option<SemanticVersion>,
body: CompletionBody,
use_cloud: bool,
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
@ -550,11 +544,9 @@ impl CloudLanguageModel {
let mut refreshed_token = false;
loop {
let request_builder = http_client::Request::builder().method(Method::POST).uri(
http_client
.build_zed_llm_url("/completions", &[], use_cloud)?
.as_ref(),
);
let request_builder = http_client::Request::builder()
.method(Method::POST)
.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
let request_builder = if let Some(app_version) = app_version {
request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
} else {
@ -653,8 +645,62 @@ struct ApiError {
headers: HeaderMap<HeaderValue>,
}
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
/// "code": "upstream_http_error",
/// "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
/// "upstream_status": 503
/// }
/// ```
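///
/// A newer variant encodes the upstream status in the code string itself and
/// may carry a retry hint; the shape below is mirrored from the tests at the
/// bottom of this file:
/// ```json
/// {
/// "code": "upstream_http_429",
/// "message": "Upstream Anthropic rate limit exceeded.",
/// "retry_after": 30.5
/// }
/// ```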
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
code: String,
message: String,
#[serde(default)]
#[serde(deserialize_with = "deserialize_optional_status_code")]
upstream_status: Option<StatusCode>,
#[serde(default)]
retry_after: Option<f64>,
}
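/// Deserializes an optional numeric status code, mapping values that are not
/// valid HTTP status codes to `None` rather than failing deserialization.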
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<u16> = Option::deserialize(deserializer)?;
Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
if cloud_error.code.starts_with("upstream_http_") {
let status = if let Some(status) = cloud_error.upstream_status {
status
} else if cloud_error.code.ends_with("_error") {
error.status
} else {
// The status code is embedded in the code string itself (e.g.
// "upstream_http_429"); parse it out, falling back to the response's
// own status.
cloud_error
.code
.strip_prefix("upstream_http_")
.and_then(|code_str| code_str.parse::<u16>().ok())
.and_then(|code| StatusCode::from_u16(code).ok())
.unwrap_or(error.status)
};
return LanguageModelCompletionError::UpstreamProviderError {
message: cloud_error.message,
status,
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
};
}
}
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
@ -781,7 +827,6 @@ impl LanguageModel for CloudLanguageModel {
let model_id = self.model.id.to_string();
let generate_content_request =
into_google(request, model_id.clone(), GoogleModelMode::Default);
let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
async move {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
@ -797,7 +842,7 @@ impl LanguageModel for CloudLanguageModel {
.method(Method::POST)
.uri(
http_client
.build_zed_llm_url("/count_tokens", &[], use_cloud)?
.build_zed_llm_url("/count_tokens", &[])?
.as_ref(),
)
.header("Content-Type", "application/json")
@ -846,9 +891,7 @@ impl LanguageModel for CloudLanguageModel {
let intent = request.intent;
let mode = request.mode;
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
let use_cloud = cx
.update(|cx| cx.has_flag::<ZedCloudFeatureFlag>())
.unwrap_or(false);
let thinking_allowed = request.thinking_allowed;
match self.model.provider {
zed_llm_client::LanguageModelProvider::Anthropic => {
let request = into_anthropic(
@ -856,7 +899,7 @@ impl LanguageModel for CloudLanguageModel {
self.model.id.to_string(),
1.0,
self.model.max_output_tokens as u64,
if self.model.id.0.ends_with("-thinking") {
if thinking_allowed && self.model.id.0.ends_with("-thinking") {
AnthropicModelMode::Thinking {
budget_tokens: Some(4_096),
}
@ -886,7 +929,6 @@ impl LanguageModel for CloudLanguageModel {
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
use_cloud,
)
.await
.map_err(|err| match err.downcast::<ApiError>() {
@ -939,7 +981,6 @@ impl LanguageModel for CloudLanguageModel {
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
use_cloud,
)
.await?;
@ -980,7 +1021,6 @@ impl LanguageModel for CloudLanguageModel {
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
},
use_cloud,
)
.await?;
@ -1087,6 +1127,7 @@ struct ZedAiConfiguration {
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
eligible_for_trial: bool,
has_accepted_terms_of_service: bool,
account_too_young: bool,
accept_terms_of_service_in_progress: bool,
accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
@ -1094,89 +1135,98 @@ struct ZedAiConfiguration {
impl RenderOnce for ZedAiConfiguration {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
let young_account_banner = YoungAccountBanner;
let is_pro = self.plan == Some(proto::Plan::ZedPro);
let subscription_text = match (self.plan, self.subscription_period) {
(Some(proto::Plan::ZedPro), Some(_)) => {
"You have access to Zed's hosted LLMs through your Zed Pro subscription."
"You have access to Zed's hosted models through your Pro subscription."
}
(Some(proto::Plan::ZedProTrial), Some(_)) => {
"You have access to Zed's hosted LLMs through your Zed Pro trial."
"You have access to Zed's hosted models through your Pro trial."
}
(Some(proto::Plan::Free), Some(_)) => {
"You have basic access to Zed's hosted LLMs through your Zed Free subscription."
"You have basic access to Zed's hosted models through the Free plan."
}
_ => {
if self.eligible_for_trial {
"Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
"Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
} else {
"Subscribe for access to Zed's hosted LLMs."
"Subscribe for access to Zed's hosted models."
}
}
};
let manage_subscription_buttons = if is_pro {
h_flex().child(
Button::new("manage_settings", "Manage Subscription")
.style(ButtonStyle::Tinted(TintColor::Accent))
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
)
Button::new("manage_settings", "Manage Subscription")
.full_width()
.style(ButtonStyle::Tinted(TintColor::Accent))
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
.into_any_element()
} else if self.plan.is_none() || self.eligible_for_trial {
Button::new("start_trial", "Start 14-day Free Pro Trial")
.full_width()
.style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
.into_any_element()
} else {
h_flex()
.gap_2()
.child(
Button::new("learn_more", "Learn more")
.style(ButtonStyle::Subtle)
.on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
)
.child(
Button::new(
"upgrade",
if self.plan.is_none() && self.eligible_for_trial {
"Start Trial"
} else {
"Upgrade"
},
)
.style(ButtonStyle::Subtle)
.color(Color::Accent)
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
)
Button::new("upgrade", "Upgrade to Pro")
.full_width()
.style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
.into_any_element()
};
if self.is_connected {
v_flex()
.gap_3()
.w_full()
.when(!self.has_accepted_terms_of_service, |this| {
this.child(render_accept_terms(
LanguageModelProviderTosView::Configuration,
self.accept_terms_of_service_in_progress,
{
let callback = self.accept_terms_of_service_callback.clone();
move |window, cx| (callback)(window, cx)
},
))
})
.when(self.has_accepted_terms_of_service, |this| {
this.child(subscription_text)
.child(manage_subscription_buttons)
})
} else {
v_flex()
if !self.is_connected {
return v_flex()
.gap_2()
.child(Label::new("Use Zed AI to access hosted language models."))
.child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
.child(
Button::new("sign_in", "Sign In")
Button::new("sign_in", "Sign In to use Zed AI")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.full_width()
.on_click({
let callback = self.sign_in_callback.clone();
move |_, window, cx| (callback)(window, cx)
}),
)
);
}
v_flex()
.gap_2()
.w_full()
.when(!self.has_accepted_terms_of_service, |this| {
this.child(render_accept_terms(
LanguageModelProviderTosView::Configuration,
self.accept_terms_of_service_in_progress,
{
let callback = self.accept_terms_of_service_callback.clone();
move |window, cx| (callback)(window, cx)
},
))
})
.map(|this| {
if self.has_accepted_terms_of_service && self.account_too_young {
this.child(young_account_banner).child(
Button::new("upgrade", "Upgrade to Pro")
.style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
.full_width()
.on_click(|_, _, cx| {
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
}),
)
} else if self.has_accepted_terms_of_service {
this.text_sm()
.child(subscription_text)
.child(manage_subscription_buttons)
} else {
this
}
})
.when(self.has_accepted_terms_of_service, |this| this)
}
}
@ -1225,6 +1275,7 @@ impl Render for ConfigurationView {
subscription_period: user_store.subscription_period(),
eligible_for_trial: user_store.trial_started_at().is_none(),
has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
account_too_young: user_store.account_too_young(),
accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
sign_in_callback: self.sign_in_callback.clone(),
@ -1242,6 +1293,7 @@ impl Component for ZedAiConfiguration {
is_connected: bool,
plan: Option<proto::Plan>,
eligible_for_trial: bool,
account_too_young: bool,
has_accepted_terms_of_service: bool,
) -> AnyElement {
ZedAiConfiguration {
@ -1252,6 +1304,7 @@ impl Component for ZedAiConfiguration {
.then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
eligible_for_trial,
has_accepted_terms_of_service,
account_too_young,
accept_terms_of_service_in_progress: false,
accept_terms_of_service_callback: Arc::new(|_, _| {}),
sign_in_callback: Arc::new(|_, _| {}),
@ -1264,33 +1317,188 @@ impl Component for ZedAiConfiguration {
.p_4()
.gap_4()
.children(vec![
single_example("Not connected", configuration(false, None, false, true)),
single_example(
"Not connected",
configuration(false, None, false, false, true),
),
single_example(
"Accept Terms of Service",
configuration(true, None, true, false),
configuration(true, None, true, false, false),
),
single_example(
"No Plan - Not eligible for trial",
configuration(true, None, false, true),
configuration(true, None, false, false, true),
),
single_example(
"No Plan - Eligible for trial",
configuration(true, None, true, true),
configuration(true, None, true, false, true),
),
single_example(
"Free Plan",
configuration(true, Some(proto::Plan::Free), true, true),
configuration(true, Some(proto::Plan::Free), true, false, true),
),
single_example(
"Zed Pro Trial Plan",
configuration(true, Some(proto::Plan::ZedProTrial), true, true),
configuration(true, Some(proto::Plan::ZedProTrial), true, false, true),
),
single_example(
"Zed Pro Plan",
configuration(true, Some(proto::Plan::ZedPro), true, true),
configuration(true, Some(proto::Plan::ZedPro), true, false, true),
),
])
.into_any_element(),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use http_client::http::{HeaderMap, StatusCode};
use language_model::LanguageModelCompletionError;
#[test]
fn test_api_error_conversion_with_upstream_http_error() {
// upstream_http_error with a 503 upstream status should become UpstreamProviderError
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
assert_eq!(
message,
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
);
}
_ => panic!(
"Expected UpstreamProviderError for upstream 503, got: {:?}",
completion_error
),
}
// upstream_http_error with a 500 upstream status should become UpstreamProviderError
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
assert_eq!(
message,
"Received an error from the OpenAI API: internal server error"
);
}
_ => panic!(
"Expected UpstreamProviderError for upstream 500, got: {:?}",
completion_error
),
}
// upstream_http_error with a 429 upstream status should become UpstreamProviderError
let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
assert_eq!(
message,
"Received an error from the Google API: rate limit exceeded"
);
}
_ => panic!(
"Expected UpstreamProviderError for upstream 429, got: {:?}",
completion_error
),
}
// Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
let error_body = "Regular internal server error";
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider, PROVIDER_NAME);
assert_eq!(message, "Regular internal server error");
}
_ => panic!(
"Expected ApiInternalServerError for regular 500, got: {:?}",
completion_error
),
}
// upstream_http_429 format should be converted to UpstreamProviderError
let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::UpstreamProviderError {
message,
status,
retry_after,
} => {
assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
}
_ => panic!(
"Expected UpstreamProviderError for upstream_http_429, got: {:?}",
completion_error
),
}
// Invalid JSON in error body should fall back to regular error handling
let error_body = "Not JSON at all";
let api_error = ApiError {
status: StatusCode::INTERNAL_SERVER_ERROR,
body: error_body.to_string(),
headers: HeaderMap::new(),
};
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
assert_eq!(provider, PROVIDER_NAME);
}
_ => panic!(
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
completion_error
),
}
}
}


@ -94,6 +94,7 @@ pub struct State {
_subscription: Subscription,
}
const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
impl State {
@ -151,6 +152,8 @@ impl State {
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
(api_key, true)
} else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, &cx)
@ -559,11 +562,11 @@ pub fn into_google(
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
thinking_config: match mode {
GoogleModelMode::Thinking { budget_tokens } => {
thinking_config: match (request.thinking_allowed, mode) {
(true, GoogleModelMode::Thinking { budget_tokens }) => {
budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
}
GoogleModelMode::Default => None,
_ => None,
},
top_p: None,
top_k: None,
@ -903,7 +906,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@ -922,7 +925,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
} else {
"API key configured.".to_string()
})),
@ -935,7 +938,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable.")))
this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)


@ -410,8 +410,20 @@ pub fn into_mistral(
.push_part(mistral::MessagePart::Text { text: text.clone() });
}
MessageContent::RedactedThinking(_) => {}
MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => {
// Tool content is not supported in User messages for Mistral
MessageContent::ToolUse(_) => {
// Tool use is not supported in User messages for Mistral
}
MessageContent::ToolResult(tool_result) => {
let tool_content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string(),
LanguageModelToolResultContent::Image(_) => {
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
}
};
messages.push(mistral::RequestMessage::Tool {
content: tool_content,
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
}
}
@ -482,24 +494,6 @@ pub fn into_mistral(
}
}
for message in &request.messages {
for content in &message.content {
if let MessageContent::ToolResult(tool_result) = content {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => text.to_string(),
LanguageModelToolResultContent::Image(_) => {
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
}
};
messages.push(mistral::RequestMessage::Tool {
content,
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
}
}
// The Mistral API requires that tool messages be followed by assistant messages,
// not user messages. When we have a tool->user sequence in the conversation,
// we need to insert a placeholder assistant message to maintain proper conversation
@ -911,6 +905,7 @@ mod tests {
intent: None,
mode: None,
stop: vec![],
thinking_allowed: true,
};
let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);
@ -943,6 +938,7 @@ mod tests {
intent: None,
mode: None,
stop: vec![],
thinking_allowed: true,
};
let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);
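Aside: the reordering described by the comment in `into_mistral` above (the Mistral API requires that tool messages be followed by assistant messages, not user messages) can be sketched as follows. This is a minimal illustration with stand-in types, not Zed's actual implementation:

```rust
// Stand-in message type; Zed's real code uses mistral::RequestMessage.
#[derive(Debug, PartialEq)]
enum Msg {
    User,
    Assistant,
    Tool,
}

/// Insert a placeholder assistant message wherever a tool message is
/// directly followed by a user message.
fn fix_tool_user_sequences(messages: &mut Vec<Msg>) {
    let mut i = 0;
    while i + 1 < messages.len() {
        if messages[i] == Msg::Tool && messages[i + 1] == Msg::User {
            messages.insert(i + 1, Msg::Assistant);
            i += 1; // step over the inserted placeholder
        }
        i += 1;
    }
}

fn main() {
    let mut msgs = vec![Msg::User, Msg::Tool, Msg::User];
    fix_tool_user_sequences(&mut msgs);
    assert_eq!(msgs, [Msg::User, Msg::Tool, Msg::Assistant, Msg::User]);
}
```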


@ -415,7 +415,10 @@ impl OllamaLanguageModel {
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
think: self.model.supports_thinking,
think: self
.model
.supports_thinking
.map(|supports_thinking| supports_thinking && request.thinking_allowed),
tools: request.tools.into_iter().map(tool_into_ollama).collect(),
}
}


@ -2,7 +2,6 @@ use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
@ -18,7 +17,7 @@ use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore, update_settings_file};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
@ -28,7 +27,6 @@ use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
@ -621,26 +619,32 @@ struct RawToolCall {
arguments: String,
}
pub(crate) fn collect_tiktoken_messages(
request: LanguageModelRequest,
) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>()
}
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: Model,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
let messages = collect_tiktoken_messages(request);
match model {
Model::Custom { max_tokens, .. } => {
@ -678,7 +682,6 @@ pub fn count_open_ai_tokens(
struct ConfigurationView {
api_key_editor: Entity<SingleLineInput>,
api_url_editor: Entity<SingleLineInput>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
@ -691,23 +694,6 @@ impl ConfigurationView {
cx,
"sk-000000000000000000000000000000000000000000000000",
)
.label("API key")
});
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let api_url_editor = cx.new(|cx| {
let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");
if !api_url.is_empty() {
input.editor.update(cx, |editor, cx| {
editor.set_text(&*api_url, window, cx);
});
}
input
});
cx.observe(&state, |_, _, cx| {
@ -735,7 +721,6 @@ impl ConfigurationView {
Self {
api_key_editor,
api_url_editor,
state,
load_credentials_task,
}
@ -783,57 +768,6 @@ impl ConfigurationView {
cx.notify();
}
fn save_api_url(&mut self, cx: &mut Context<Self>) {
let api_url = self
.api_url_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
let current_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let effective_current_url = if current_url.is_empty() {
open_ai::OPEN_AI_API_URL
} else {
&current_url
};
if !api_url.is_empty() && api_url != effective_current_url {
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
if let Some(settings) = settings.openai.as_mut() {
settings.api_url = Some(api_url.clone());
} else {
settings.openai = Some(OpenAiSettingsContent {
api_url: Some(api_url.clone()),
available_models: None,
});
}
});
}
}
fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_url_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
if let Some(settings) = settings.openai.as_mut() {
settings.api_url = None;
}
});
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
@ -846,7 +780,6 @@ impl Render for ConfigurationView {
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
@ -910,59 +843,34 @@ impl Render for ConfigurationView {
.into_any()
};
let custom_api_url_set =
AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;
let api_url_section = if custom_api_url_set {
h_flex()
.mt_1()
.p_1()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new("Custom API URL configured.")),
)
.child(
Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
),
)
.into_any()
} else {
v_flex()
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
this.save_api_url(cx);
cx.notify();
}))
.mt_2()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
.child(
List::new()
.child(InstructionListItem::text_only(
"Optionally, you can change the base URL for the OpenAI API request.",
))
.child(InstructionListItem::text_only(
"Paste the new API endpoint below and hit enter",
)),
)
.child(self.api_url_editor.clone())
.into_any()
};
let compatible_api_section = h_flex()
.mt_1p5()
.gap_0p5()
.flex_wrap()
.when(self.should_render_editor(cx), |this| {
this.pt_1p5()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.child(
h_flex()
.gap_2()
.child(
Icon::new(IconName::Info)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new("Zed also supports OpenAI-compatible models.")),
)
.child(
Button::new("docs", "Learn More")
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible")
}),
);
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
@ -970,7 +878,7 @@ impl Render for ConfigurationView {
v_flex()
.size_full()
.child(api_key_section)
.child(api_url_section)
.child(compatible_api_section)
.into_any()
}
}
@ -999,6 +907,7 @@ mod tests {
tool_choice: None,
stop: vec![],
temperature: None,
thinking_allowed: true,
};
// Validate that all models are supported by tiktoken-rs


@ -0,0 +1,522 @@
use anyhow::{Context as _, Result, anyhow};
use credentials_provider::CredentialsProvider;
use convert_case::{Case, Casing};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter,
};
use menu;
use open_ai::{ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
}
pub struct OpenAiCompatibleLanguageModelProvider {
id: LanguageModelProviderId,
name: LanguageModelProviderName,
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
id: Arc<str>,
env_var_name: Arc<str>,
api_key: Option<String>,
api_key_from_env: bool,
settings: OpenAiCompatibleSettings,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let env_var_name = self.env_var_name.clone();
let api_url = self.settings.api_url.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, &cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context("invalid API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
impl OpenAiCompatibleLanguageModelProvider {
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
AllLanguageModelSettings::get_global(cx)
.openai_compatible
.get(id)
}
let state = cx.new(|cx| State {
id: id.clone(),
env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let Some(settings) = resolve_settings(&this.id, cx) else {
return;
};
if &this.settings != settings {
this.settings = settings.clone();
cx.notify();
}
}),
});
Self {
id: id.clone().into(),
name: id.into(),
http_client,
state,
}
}
fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
Arc::new(OpenAiCompatibleLanguageModel {
id: LanguageModelId::from(model.name.clone()),
provider_id: self.id.clone(),
provider_name: self.name.clone(),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
}
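As an aside, the `env_var_name` derivation in `new` above can be checked with a small sketch (assuming the same `convert_case` crate imported at the top of this file; the provider id is hypothetical):

```rust
use convert_case::{Case, Casing};

fn main() {
    // A provider keyed as "my-provider" gets the env var MY_PROVIDER_API_KEY.
    let env_var = format!("{}_API_KEY", "my-provider").to_case(Case::Constant);
    assert_eq!(env_var, "MY_PROVIDER_API_KEY");
}
```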
impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
self.id.clone()
}
fn name(&self) -> LanguageModelProviderName {
self.name.clone()
}
fn icon(&self) -> IconName {
IconName::AiOpenAiCompat
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.state
.read(cx)
.settings
.available_models
.first()
.map(|model| self.create_language_model(model.clone()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
None
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
self.state
.read(cx)
.settings
.available_models
.iter()
.map(|model| self.create_language_model(model.clone()))
.collect()
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
}
}
pub struct OpenAiCompatibleLanguageModel {
id: LanguageModelId,
provider_id: LanguageModelProviderId,
provider_name: LanguageModelProviderName,
model: AvailableModel,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OpenAiCompatibleLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
(state.api_key.clone(), state.settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let provider = self.provider_name.clone();
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for OpenAiCompatibleLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(
self.model
.display_name
.clone()
.unwrap_or_else(|| self.model.name.clone()),
)
}
fn provider_id(&self) -> LanguageModelProviderId {
self.provider_id.clone()
}
fn provider_name(&self) -> LanguageModelProviderName {
self.provider_name.clone()
}
fn supports_tools(&self) -> bool {
true
}
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto => true,
LanguageModelToolChoice::Any => true,
LanguageModelToolChoice::None => true,
}
}
fn telemetry_id(&self) -> String {
format!("openai/{}", self.model.name)
}
fn max_token_count(&self) -> u64 {
self.model.max_tokens
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
let max_token_count = self.max_token_count();
cx.background_spawn(async move {
let messages = super::open_ai::collect_tiktoken_messages(request);
let model = if max_token_count >= 100_000 {
// If the max token count is 100k or more, the model likely uses the o200k_base tokenizer (as gpt-4o does)
"gpt-4o"
} else {
// Otherwise fall back to gpt-4, since this tiktoken method only supports
// cl100k_base and o200k_base
"gpt-4"
};
tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
})
.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
let mapper = OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
}
}
struct ConfigurationView {
api_key_editor: Entity<SingleLineInput>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor = cx.new(|cx| {
SingleLineInput::new(
window,
cx,
"000000000000000000000000000000000000000000000000000",
)
});
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
let load_credentials_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
// Intentionally ignore the result: "not signed in" also surfaces as an error here and is expected.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
api_key_editor,
state,
load_credentials_task,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let env_var_name = self.state.read(cx).env_var_name.clone();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key."))
.child(
div()
.pt(DynamicSpacing::Base04.rems(cx))
.child(self.api_key_editor.clone())
)
.child(
Label::new(
format!("You can also assign the {env_var_name} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
.into_any()
} else {
h_flex()
.mt_1()
.p_1()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {env_var_name} environment variable.")
} else {
"API key configured.".to_string()
})),
)
.child(
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
.into_any()
};
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
} else {
v_flex().size_full().child(api_key_section).into_any()
}
}
}


@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel {
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
let model_id = self.model.id().trim().to_lowercase();
if model_id.contains("gemini") {
if model_id.contains("gemini") || model_id.contains("grok-4") {
LanguageModelToolSchemaFormat::JsonSchemaSubset
} else {
LanguageModelToolSchemaFormat::JsonSchema
@ -523,7 +523,9 @@ pub fn into_open_router(
None
},
usage: open_router::RequestUsage { include: true },
reasoning: if let OpenRouterModelMode::Thinking { budget_tokens } = model.mode {
reasoning: if request.thinking_allowed
&& let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
{
Some(open_router::Reasoning {
effort: None,
max_tokens: budget_tokens,


@ -0,0 +1,571 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use x_ai::Model;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "x_ai";
const PROVIDER_NAME: &str = "xAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct XAiSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
}
pub struct XAiLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
}
const XAI_API_KEY_VAR: &str = "XAI_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, &cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context(format!("invalid {PROVIDER_NAME} API key"))?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
impl XAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: x_ai::Model) -> Arc<dyn LanguageModel> {
Arc::new(XAiLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
}
impl LanguageModelProviderState for XAiLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for XAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiXAi
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(x_ai::Model::default()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(x_ai::Model::default_fast()))
}
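/// Merges built-in models with `available_models` from settings. Models are
/// keyed by id in a BTreeMap, so settings entries override built-ins with the
/// same id and the result comes back sorted by id.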
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
for model in x_ai::Model::iter() {
if !matches!(model, x_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
for model in &AllLanguageModelSettings::get_global(cx)
.x_ai
.available_models
{
models.insert(
model.name.clone(),
x_ai::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
},
);
}
models
.into_values()
.map(|model| self.create_language_model(model))
.collect()
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
}
}
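/// A single xAI model exposed through the `LanguageModel` interface; requests
/// and responses are translated to and from the OpenAI wire format.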
pub struct XAiLanguageModel {
id: LanguageModelId,
model: x_ai::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl XAiLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
let api_url = if settings.api_url.is_empty() {
x_ai::XAI_API_URL.to_string()
} else {
settings.api_url.clone()
};
(state.api_key.clone(), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing xAI API Key")?;
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for XAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn supports_tools(&self) -> bool {
self.model.supports_tool()
}
fn supports_images(&self) -> bool {
self.model.supports_images()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
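/// Grok 4 is special-cased to a JSON Schema subset for tool input; every other
/// model advertises full JSON Schema support.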
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
let model_id = self.model.id().trim().to_lowercase();
if model_id == x_ai::Model::Grok4.id() {
LanguageModelToolSchemaFormat::JsonSchemaSubset
} else {
LanguageModelToolSchemaFormat::JsonSchema
}
}
fn telemetry_id(&self) -> String {
format!("x_ai/{}", self.model.id())
}
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
count_xai_tokens(request, self.model.clone(), cx)
}
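// Streaming goes through the OpenAI-compatible path: the request is converted
// with `into_open_ai`, streamed, and events are mapped back through
// `OpenAiEventMapper`.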
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
let request = crate::provider::open_ai::into_open_ai(
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
);
let completions = self.stream_completion(request, cx);
async move {
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
}
}
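/// Approximates token usage with tiktoken. Requests are sent in the OpenAI
/// wire format, and no xAI-specific tokenizer is consulted here.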
pub fn count_xai_tokens(
request: LanguageModelRequest,
model: Model,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
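// Tokenizer choice is a heuristic: large-context models are counted with the
// "gpt-4o" vocabulary, everything else with "gpt-4".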
let model_name = if model.max_token_count() >= 100_000 {
"gpt-4o"
} else {
"gpt-4"
};
tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
})
.boxed()
}
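/// Configuration UI for entering, saving, and resetting the xAI API key.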
struct ConfigurationView {
api_key_editor: Entity<SingleLineInput>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor = cx.new(|cx| {
SingleLineInput::new(
window,
cx,
"xai-0000000000000000000000000000000000000000000000000",
)
.label("API key")
});
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
let load_credentials_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
api_key_editor,
state,
load_credentials_task,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("xAI console"),
Some("https://console.x.ai/team/default/api-keys"),
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the agent",
)),
)
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
"You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
Label::new("Note that xAI is a custom OpenAI-compatible provider.")
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any()
} else {
h_flex()
.mt_1()
.p_1()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {XAI_API_KEY_VAR} environment variable.")
} else {
"API key configured.".to_string()
})),
)
.child(
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
.into_any()
};
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
} else {
v_flex().size_full().child(api_key_section).into_any()
}
}
}

View file

@@ -1,4 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use gpui::App;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -15,12 +18,14 @@ use crate::provider::{
mistral::MistralSettings,
ollama::OllamaSettings,
open_ai::OpenAiSettings,
open_ai_compatible::OpenAiCompatibleSettings,
open_router::OpenRouterSettings,
vercel::VercelSettings,
x_ai::XAiSettings,
};
/// Initializes the language model settings.
pub fn init(cx: &mut App) {
pub fn init_settings(cx: &mut App) {
AllLanguageModelSettings::register(cx);
}
@@ -28,33 +33,35 @@ pub fn init(cx: &mut App) {
pub struct AllLanguageModelSettings {
pub anthropic: AnthropicSettings,
pub bedrock: AmazonBedrockSettings,
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
pub open_router: OpenRouterSettings,
pub zed_dot_dev: ZedDotDevSettings,
pub google: GoogleSettings,
pub vercel: VercelSettings,
pub lmstudio: LmStudioSettings,
pub deepseek: DeepSeekSettings,
pub google: GoogleSettings,
pub lmstudio: LmStudioSettings,
pub mistral: MistralSettings,
pub ollama: OllamaSettings,
pub open_router: OpenRouterSettings,
pub openai: OpenAiSettings,
pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>,
pub vercel: VercelSettings,
pub x_ai: XAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AllLanguageModelSettingsContent {
pub anthropic: Option<AnthropicSettingsContent>,
pub bedrock: Option<AmazonBedrockSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
pub deepseek: Option<DeepseekSettingsContent>,
pub google: Option<GoogleSettingsContent>,
pub lmstudio: Option<LmStudioSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
pub open_router: Option<OpenRouterSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>,
pub vercel: Option<VercelSettingsContent>,
pub x_ai: Option<XAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
pub google: Option<GoogleSettingsContent>,
pub deepseek: Option<DeepseekSettingsContent>,
pub vercel: Option<VercelSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -102,6 +109,12 @@ pub struct OpenAiSettingsContent {
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiCompatibleSettingsContent {
pub api_url: String,
pub available_models: Vec<provider::open_ai_compatible::AvailableModel>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct VercelSettingsContent {
pub api_url: Option<String>,
@@ -114,6 +127,12 @@ pub struct GoogleSettingsContent {
pub available_models: Option<Vec<provider::google::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct XAiSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::x_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<cloud::AvailableModel>>,
@@ -219,6 +238,19 @@ impl settings::Settings for AllLanguageModelSettings {
openai.as_ref().and_then(|s| s.available_models.clone()),
);
// OpenAI Compatible
if let Some(openai_compatible) = value.openai_compatible.clone() {
for (id, openai_compatible_settings) in openai_compatible {
settings.openai_compatible.insert(
id,
OpenAiCompatibleSettings {
api_url: openai_compatible_settings.api_url,
available_models: openai_compatible_settings.available_models,
},
);
}
}
// Vercel
let vercel = value.vercel.clone();
merge(
@@ -230,6 +262,18 @@ impl settings::Settings for AllLanguageModelSettings {
vercel.as_ref().and_then(|s| s.available_models.clone()),
);
// XAI
let x_ai = value.x_ai.clone();
merge(
&mut settings.x_ai.api_url,
x_ai.as_ref().and_then(|s| s.api_url.clone()),
);
merge(
&mut settings.x_ai.available_models,
x_ai.as_ref().and_then(|s| s.available_models.clone()),
);
// ZedDotDev
merge(
&mut settings.zed_dot_dev.available_models,
value