Merge branch 'main' into push-trzsxukkukpr

commit 2dde3fd58c

553 changed files with 37661 additions and 11007 deletions
@@ -12,6 +12,7 @@ workspace = true
path = "src/language_models.rs"

[dependencies]
+ai_onboarding.workspace = true
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
aws-config = { workspace = true, features = ["behavior-version-latest"] }

@@ -25,11 +26,10 @@ client.workspace = true
collections.workspace = true
component.workspace = true
credentials_provider.workspace = true
+convert_case.workspace = true
copilot.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true

@@ -44,6 +44,7 @@ ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
proto.workspace = true
release_channel.workspace = true

@@ -1,8 +1,10 @@
use std::sync::Arc;

use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
+use collections::HashSet;
use gpui::{App, Context, Entity};
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;

pub mod provider;

@@ -18,16 +20,81 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
+use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;

pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
-    crate::settings::init(cx);
+    crate::settings::init_settings(cx);
    let registry = LanguageModelRegistry::global(cx);
    registry.update(cx, |registry, cx| {
-        register_language_model_providers(registry, user_store, client, cx);
+        register_language_model_providers(registry, user_store, client.clone(), cx);
    });

+    let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
+        .openai_compatible
+        .keys()
+        .cloned()
+        .collect::<HashSet<_>>();
+
+    registry.update(cx, |registry, cx| {
+        register_openai_compatible_providers(
+            registry,
+            &HashSet::default(),
+            &openai_compatible_providers,
+            client.clone(),
+            cx,
+        );
+    });
+    cx.observe_global::<SettingsStore>(move |cx| {
+        let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
+            .openai_compatible
+            .keys()
+            .cloned()
+            .collect::<HashSet<_>>();
+        if openai_compatible_providers_new != openai_compatible_providers {
+            registry.update(cx, |registry, cx| {
+                register_openai_compatible_providers(
+                    registry,
+                    &openai_compatible_providers,
+                    &openai_compatible_providers_new,
+                    client.clone(),
+                    cx,
+                );
+            });
+            openai_compatible_providers = openai_compatible_providers_new;
+        }
+    })
+    .detach();
}

+fn register_openai_compatible_providers(
+    registry: &mut LanguageModelRegistry,
+    old: &HashSet<Arc<str>>,
+    new: &HashSet<Arc<str>>,
+    client: Arc<Client>,
+    cx: &mut Context<LanguageModelRegistry>,
+) {
+    for provider_id in old {
+        if !new.contains(provider_id) {
+            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
+        }
+    }
+
+    for provider_id in new {
+        if !old.contains(provider_id) {
+            registry.register_provider(
+                OpenAiCompatibleLanguageModelProvider::new(
+                    provider_id.clone(),
+                    client.http_client(),
+                    cx,
+                ),
+                cx,
+            );
+        }
+    }
+}

fn register_language_model_providers(
@@ -81,5 +148,6 @@ fn register_language_model_providers(
        VercelLanguageModelProvider::new(client.http_client(), cx),
        cx,
    );
    registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}

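The new `register_openai_compatible_providers` above reconciles the provider registry against the settings map by set difference: ids present only in `old` are unregistered, ids present only in `new` are registered, and ids in both are left untouched, so a settings change only disturbs the providers it actually affects. A minimal, framework-free sketch of that reconcile step (names are illustrative, not Zed's API):

```rust
use std::collections::HashSet;

/// Split a settings change into the provider ids to unregister and to
/// register, leaving ids present in both sets untouched.
fn reconcile(old: &HashSet<String>, new: &HashSet<String>) -> (Vec<String>, Vec<String>) {
    let to_unregister = old.difference(new).cloned().collect();
    let to_register = new.difference(old).cloned().collect();
    (to_unregister, to_register)
}

fn main() {
    let old = HashSet::from(["proxy".to_string()]);
    let new = HashSet::from(["proxy".to_string(), "local-vllm".to_string()]);
    let (to_unregister, to_register) = reconcile(&old, &new);
    assert!(to_unregister.is_empty());
    assert_eq!(to_register, vec!["local-vllm".to_string()]);
}
```

Because `init` keeps the previous key set captured in the `observe_global` closure and replaces it after each reconcile, repeated settings events stay idempotent.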
@@ -8,5 +8,7 @@ pub mod lmstudio;
pub mod mistral;
pub mod ollama;
pub mod open_ai;
+pub mod open_ai_compatible;
pub mod open_router;
pub mod vercel;
pub mod x_ai;

@@ -663,7 +663,9 @@ pub fn into_anthropic(
        } else {
            Some(anthropic::StringOrContents::String(system_message))
        },
-        thinking: if let AnthropicModelMode::Thinking { budget_tokens } = mode {
+        thinking: if request.thinking_allowed
+            && let AnthropicModelMode::Thinking { budget_tokens } = mode
+        {
            Some(anthropic::Thinking::Enabled { budget_tokens })
        } else {
            None

@@ -1108,6 +1110,7 @@ mod tests {
            temperature: None,
            tools: vec![],
            tool_choice: None,
+            thinking_allowed: true,
        };

        let anthropic_request = into_anthropic(

@@ -243,7 +243,7 @@ impl State {

pub struct BedrockLanguageModelProvider {
    http_client: AwsHttpClient,
-    handler: tokio::runtime::Handle,
+    handle: tokio::runtime::Handle,
    state: gpui::Entity<State>,
}

@@ -258,13 +258,9 @@ impl BedrockLanguageModelProvider {
            }),
        });

-        let tokio_handle = Tokio::handle(cx);
-
-        let coerced_client = AwsHttpClient::new(http_client.clone(), tokio_handle.clone());
-
        Self {
-            http_client: coerced_client,
-            handler: tokio_handle.clone(),
+            http_client: AwsHttpClient::new(http_client.clone()),
+            handle: Tokio::handle(cx),
            state,
        }
    }

@@ -274,7 +270,7 @@ impl BedrockLanguageModelProvider {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            http_client: self.http_client.clone(),
-            handler: self.handler.clone(),
+            handle: self.handle.clone(),
            state: self.state.clone(),
            client: OnceCell::new(),
            request_limiter: RateLimiter::new(4),

@@ -375,7 +371,7 @@ struct BedrockModel {
    id: LanguageModelId,
    model: Model,
    http_client: AwsHttpClient,
-    handler: tokio::runtime::Handle,
+    handle: tokio::runtime::Handle,
    client: OnceCell<BedrockClient>,
    state: gpui::Entity<State>,
    request_limiter: RateLimiter,

@@ -447,7 +443,7 @@ impl BedrockModel {
            }
        }

-        let config = self.handler.block_on(config_builder.load());
+        let config = self.handle.block_on(config_builder.load());
        anyhow::Ok(BedrockClient::new(&config))
    })
    .context("initializing Bedrock client")?;

@@ -799,7 +795,9 @@ pub fn into_bedrock(
        max_tokens: max_output_tokens,
        system: Some(system_message),
        tools: Some(tool_config),
-        thinking: if let BedrockModelMode::Thinking { budget_tokens } = mode {
+        thinking: if request.thinking_allowed
+            && let BedrockModelMode::Thinking { budget_tokens } = mode
+        {
            Some(bedrock::Thinking::Enabled { budget_tokens })
        } else {
            None

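Both `into_anthropic` and `into_bedrock` now gate extended thinking on `request.thinking_allowed` using a let-chain (`cond && let Pat = expr`); the Google conversion later in this diff expresses the same gate as a tuple match. A stable-Rust sketch of the shared logic, with illustrative stand-in types:

```rust
#[derive(Clone, Copy)]
enum ModelMode {
    Default,
    Thinking { budget_tokens: Option<u64> },
}

/// Thinking is enabled only when the request allows it AND the model is
/// configured in thinking mode; otherwise no thinking budget is sent upstream.
fn thinking_budget(thinking_allowed: bool, mode: ModelMode) -> Option<u64> {
    match (thinking_allowed, mode) {
        (true, ModelMode::Thinking { budget_tokens }) => budget_tokens,
        _ => None,
    }
}

fn main() {
    let thinking = ModelMode::Thinking { budget_tokens: Some(4_096) };
    assert_eq!(thinking_budget(true, thinking), Some(4_096));
    assert_eq!(thinking_budget(false, thinking), None);
    assert_eq!(thinking_budget(true, ModelMode::Default), None);
}
```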
@@ -1,8 +1,8 @@
+use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
-use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};

@@ -137,7 +137,6 @@ impl State {
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
-        let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();

        Self {
            client: client.clone(),

@@ -165,47 +164,10 @@ impl State {
                    .await;
                }

-                let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
-                cx.update(|cx| {
-                    this.update(cx, |this, cx| {
-                        let mut models = Vec::new();
-
-                        for model in response.models {
-                            models.push(Arc::new(model.clone()));
-
-                            // Right now we represent thinking variants of models as separate models on the client,
-                            // so we need to insert variants for any model that supports thinking.
-                            if model.supports_thinking {
-                                models.push(Arc::new(zed_llm_client::LanguageModel {
-                                    id: zed_llm_client::LanguageModelId(
-                                        format!("{}-thinking", model.id).into(),
-                                    ),
-                                    display_name: format!("{} Thinking", model.display_name),
-                                    ..model
-                                }));
-                            }
-                        }
-
-                        this.default_model = models
-                            .iter()
-                            .find(|model| model.id == response.default_model)
-                            .cloned();
-                        this.default_fast_model = models
-                            .iter()
-                            .find(|model| model.id == response.default_fast_model)
-                            .cloned();
-                        this.recommended_models = response
-                            .recommended_models
-                            .iter()
-                            .filter_map(|id| models.iter().find(|model| &model.id == id))
-                            .cloned()
-                            .collect();
-                        this.models = models;
-                        cx.notify();
-                    })
-                })??;
-
-                anyhow::Ok(())
+                let response = Self::fetch_models(client, llm_api_token).await?;
+                this.update(cx, |this, cx| {
+                    this.update_models(response, cx);
+                })
            })
            .await
            .context("failed to fetch Zed models")

@@ -216,12 +178,15 @@ impl State {
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
-                |this, _listener, _event, cx| {
+                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
-                    cx.spawn(async move |_this, _cx| {
+                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
-                        anyhow::Ok(())
+                        let response = Self::fetch_models(client, llm_api_token).await?;
+                        this.update(cx, |this, cx| {
+                            this.update_models(response, cx);
+                        })
                    })
                    .detach_and_log_err(cx);
                },

@@ -264,21 +229,51 @@ impl State {
        }));
    }

+    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
+        let mut models = Vec::new();
+
+        for model in response.models {
+            models.push(Arc::new(model.clone()));
+
+            // Right now we represent thinking variants of models as separate models on the client,
+            // so we need to insert variants for any model that supports thinking.
+            if model.supports_thinking {
+                models.push(Arc::new(zed_llm_client::LanguageModel {
+                    id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
+                    display_name: format!("{} Thinking", model.display_name),
+                    ..model
+                }));
+            }
+        }
+
+        self.default_model = models
+            .iter()
+            .find(|model| model.id == response.default_model)
+            .cloned();
+        self.default_fast_model = models
+            .iter()
+            .find(|model| model.id == response.default_fast_model)
+            .cloned();
+        self.recommended_models = response
+            .recommended_models
+            .iter()
+            .filter_map(|id| models.iter().find(|model| &model.id == id))
+            .cloned()
+            .collect();
+        self.models = models;
+        cx.notify();
+    }
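`update_models` extracts the old inline closure body so the initial fetch and the token-refresh path above share one implementation. The notable behavior is the `-thinking` expansion: every model that `supports_thinking` is duplicated under a suffixed id. A self-contained sketch of that expansion, with `ModelInfo` as an illustrative stand-in for `zed_llm_client::LanguageModel`:

```rust
#[derive(Clone)]
struct ModelInfo {
    id: String,
    display_name: String,
    supports_thinking: bool,
}

/// Duplicate every thinking-capable model as a separate `-thinking` entry,
/// mirroring how the client represents thinking variants as distinct models.
fn expand_thinking_variants(models: Vec<ModelInfo>) -> Vec<ModelInfo> {
    let mut out = Vec::with_capacity(models.len() * 2);
    for model in models {
        out.push(model.clone());
        if model.supports_thinking {
            out.push(ModelInfo {
                id: format!("{}-thinking", model.id),
                display_name: format!("{} Thinking", model.display_name),
                ..model
            });
        }
    }
    out
}

fn main() {
    let models = vec![ModelInfo {
        id: "claude".into(),
        display_name: "Claude".into(),
        supports_thinking: true,
    }];
    let expanded = expand_thinking_variants(models);
    assert_eq!(expanded.len(), 2);
    assert_eq!(expanded[1].id, "claude-thinking");
}
```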
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
-        use_cloud: bool,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
-            .uri(
-                http_client
-                    .build_zed_llm_url("/models", &[], use_cloud)?
-                    .as_ref(),
-            )
+            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client

@@ -506,7 +501,7 @@ fn render_accept_terms(
    )
    .child({
        match view_kind {
-            LanguageModelProviderTosView::PromptEditorPopup => {
+            LanguageModelProviderTosView::TextThreadPopup => {
                button_container.w_full().justify_end()
            }
            LanguageModelProviderTosView::Configuration => {

@@ -542,7 +537,6 @@ impl CloudLanguageModel {
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
-        use_cloud: bool,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

@@ -550,11 +544,9 @@ impl CloudLanguageModel {
        let mut refreshed_token = false;

        loop {
-            let request_builder = http_client::Request::builder().method(Method::POST).uri(
-                http_client
-                    .build_zed_llm_url("/completions", &[], use_cloud)?
-                    .as_ref(),
-            );
+            let request_builder = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {

@@ -653,8 +645,62 @@ struct ApiError {
    headers: HeaderMap<HeaderValue>,
}

+/// Represents error responses from Zed's cloud API.
+///
+/// Example JSON for an upstream HTTP error:
+/// ```json
+/// {
+///   "code": "upstream_http_error",
+///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
+///   "upstream_status": 503
+/// }
+/// ```
+#[derive(Debug, serde::Deserialize)]
+struct CloudApiError {
+    code: String,
+    message: String,
+    #[serde(default)]
+    #[serde(deserialize_with = "deserialize_optional_status_code")]
+    upstream_status: Option<StatusCode>,
+    #[serde(default)]
+    retry_after: Option<f64>,
+}
+
+fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let opt: Option<u16> = Option::deserialize(deserializer)?;
+    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
+}
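`deserialize_optional_status_code` is deliberately lenient: a missing `upstream_status` stays `None`, and an invalid one degrades to `None` rather than failing the whole `CloudApiError` parse. A runnable sketch of the same pattern with plain serde/serde_json, using `u16` in place of `http::StatusCode` (and therefore tolerating out-of-range numbers that the u16-typed original would reject outright):

```rust
use serde::{Deserialize, Deserializer};

#[derive(Deserialize)]
struct CloudError {
    // `default` keeps the field optional; the custom deserializer downgrades
    // bad values to None instead of aborting the whole parse.
    #[serde(default, deserialize_with = "lenient_status")]
    upstream_status: Option<u16>,
}

fn lenient_status<'de, D>(deserializer: D) -> Result<Option<u16>, D::Error>
where
    D: Deserializer<'de>,
{
    let raw: Option<u64> = Option::deserialize(deserializer)?;
    // Keep only values that look like HTTP status codes.
    Ok(raw
        .and_then(|n| u16::try_from(n).ok())
        .filter(|n| (100..=599).contains(n)))
}

fn main() {
    let ok: CloudError = serde_json::from_str(r#"{"upstream_status": 503}"#).unwrap();
    assert_eq!(ok.upstream_status, Some(503));

    let bad: CloudError = serde_json::from_str(r#"{"upstream_status": 99999}"#).unwrap();
    assert_eq!(bad.upstream_status, None);
}
```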
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
+        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
+            if cloud_error.code.starts_with("upstream_http_") {
+                let status = if let Some(status) = cloud_error.upstream_status {
+                    status
+                } else if cloud_error.code.ends_with("_error") {
+                    error.status
+                } else {
+                    // If there's a status code in the code string (e.g. "upstream_http_429")
+                    // then use that; otherwise, see if the JSON contains a status code.
+                    cloud_error
+                        .code
+                        .strip_prefix("upstream_http_")
+                        .and_then(|code_str| code_str.parse::<u16>().ok())
+                        .and_then(|code| StatusCode::from_u16(code).ok())
+                        .unwrap_or(error.status)
+                };
+
+                return LanguageModelCompletionError::UpstreamProviderError {
+                    message: cloud_error.message,
+                    status,
+                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
+                };
+            }
+        }
+
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
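The status fallback chain in the conversion above reads: prefer an explicit `upstream_status`; otherwise, for the generic `upstream_http_error` code (which ends in `_error`), keep the outer response status; otherwise try to parse a status out of codes shaped like `upstream_http_429`. A std-only sketch of that last parse step, with `u16` standing in for `StatusCode`:

```rust
/// Extract an HTTP status from an error code such as "upstream_http_429",
/// falling back to the outer response status when no numeric suffix parses.
fn status_from_code(code: &str, fallback: u16) -> u16 {
    code.strip_prefix("upstream_http_")
        .and_then(|suffix| suffix.parse::<u16>().ok())
        .unwrap_or(fallback)
}

fn main() {
    assert_eq!(status_from_code("upstream_http_429", 500), 429);
    // The generic code has no numeric suffix, so the fallback wins.
    assert_eq!(status_from_code("upstream_http_error", 500), 500);
}
```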
@@ -781,7 +827,6 @@ impl LanguageModel for CloudLanguageModel {
        let model_id = self.model.id.to_string();
        let generate_content_request =
            into_google(request, model_id.clone(), GoogleModelMode::Default);
-        let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
        async move {
            let http_client = &client.http_client();
            let token = llm_api_token.acquire(&client).await?;

@@ -797,7 +842,7 @@ impl LanguageModel for CloudLanguageModel {
                .method(Method::POST)
                .uri(
                    http_client
-                        .build_zed_llm_url("/count_tokens", &[], use_cloud)?
+                        .build_zed_llm_url("/count_tokens", &[])?
                        .as_ref(),
                )
                .header("Content-Type", "application/json")

@@ -846,9 +891,7 @@ impl LanguageModel for CloudLanguageModel {
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
-        let use_cloud = cx
-            .update(|cx| cx.has_flag::<ZedCloudFeatureFlag>())
-            .unwrap_or(false);
+        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(

@@ -856,7 +899,7 @@ impl LanguageModel for CloudLanguageModel {
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
-                    if self.model.id.0.ends_with("-thinking") {
+                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }

@@ -886,7 +929,6 @@ impl LanguageModel for CloudLanguageModel {
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
-                        use_cloud,
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {

@@ -939,7 +981,6 @@ impl LanguageModel for CloudLanguageModel {
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
-                        use_cloud,
                    )
                    .await?;

@@ -980,7 +1021,6 @@ impl LanguageModel for CloudLanguageModel {
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
-                        use_cloud,
                    )
                    .await?;

@@ -1087,6 +1127,7 @@ struct ZedAiConfiguration {
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
+    account_too_young: bool,
    accept_terms_of_service_in_progress: bool,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,

@@ -1094,89 +1135,98 @@ struct ZedAiConfiguration {

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
-        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
+        let young_account_banner = YoungAccountBanner;

        let is_pro = self.plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
-                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
+                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
-                "You have access to Zed's hosted LLMs through your Zed Pro trial."
+                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
-                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
+                "You have basic access to Zed's hosted models through the Free plan."
            }
            _ => {
                if self.eligible_for_trial {
-                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
+                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
-                    "Subscribe for access to Zed's hosted LLMs."
+                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
-            h_flex().child(
-                Button::new("manage_settings", "Manage Subscription")
-                    .style(ButtonStyle::Tinted(TintColor::Accent))
-                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
-            )
+            Button::new("manage_settings", "Manage Subscription")
+                .full_width()
+                .style(ButtonStyle::Tinted(TintColor::Accent))
+                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
+                .into_any_element()
+        } else if self.plan.is_none() || self.eligible_for_trial {
+            Button::new("start_trial", "Start 14-day Free Pro Trial")
+                .full_width()
+                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
+                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
+                .into_any_element()
        } else {
-            h_flex()
-                .gap_2()
-                .child(
-                    Button::new("learn_more", "Learn more")
-                        .style(ButtonStyle::Subtle)
-                        .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
-                )
-                .child(
-                    Button::new(
-                        "upgrade",
-                        if self.plan.is_none() && self.eligible_for_trial {
-                            "Start Trial"
-                        } else {
-                            "Upgrade"
-                        },
-                    )
-                    .style(ButtonStyle::Subtle)
-                    .color(Color::Accent)
-                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
-                )
+            Button::new("upgrade", "Upgrade to Pro")
+                .full_width()
+                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
+                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
+                .into_any_element()
        };

-        if self.is_connected {
-            v_flex()
-                .gap_3()
-                .w_full()
-                .when(!self.has_accepted_terms_of_service, |this| {
-                    this.child(render_accept_terms(
-                        LanguageModelProviderTosView::Configuration,
-                        self.accept_terms_of_service_in_progress,
-                        {
-                            let callback = self.accept_terms_of_service_callback.clone();
-                            move |window, cx| (callback)(window, cx)
-                        },
-                    ))
-                })
-                .when(self.has_accepted_terms_of_service, |this| {
-                    this.child(subscription_text)
-                        .child(manage_subscription_buttons)
-                })
-        } else {
-            v_flex()
+        if !self.is_connected {
+            return v_flex()
                .gap_2()
-                .child(Label::new("Use Zed AI to access hosted language models."))
+                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
-                    Button::new("sign_in", "Sign In")
+                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
-                )
+                );
+        }

+        v_flex()
+            .gap_2()
+            .w_full()
+            .when(!self.has_accepted_terms_of_service, |this| {
+                this.child(render_accept_terms(
+                    LanguageModelProviderTosView::Configuration,
+                    self.accept_terms_of_service_in_progress,
+                    {
+                        let callback = self.accept_terms_of_service_callback.clone();
+                        move |window, cx| (callback)(window, cx)
+                    },
+                ))
+            })
+            .map(|this| {
+                if self.has_accepted_terms_of_service && self.account_too_young {
+                    this.child(young_account_banner).child(
+                        Button::new("upgrade", "Upgrade to Pro")
+                            .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
+                            .full_width()
+                            .on_click(|_, _, cx| {
+                                cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
+                            }),
+                    )
+                } else if self.has_accepted_terms_of_service {
+                    this.text_sm()
+                        .child(subscription_text)
+                        .child(manage_subscription_buttons)
+                } else {
+                    this
+                }
+            })
            .when(self.has_accepted_terms_of_service, |this| this)
    }
}

@@ -1225,6 +1275,7 @@ impl Render for ConfigurationView {
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
+            account_too_young: user_store.account_too_young(),
            accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
            accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
            sign_in_callback: self.sign_in_callback.clone(),

@@ -1242,6 +1293,7 @@ impl Component for ZedAiConfiguration {
        is_connected: bool,
        plan: Option<proto::Plan>,
        eligible_for_trial: bool,
+        account_too_young: bool,
        has_accepted_terms_of_service: bool,
    ) -> AnyElement {
        ZedAiConfiguration {

@@ -1252,6 +1304,7 @@ impl Component for ZedAiConfiguration {
                .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
            eligible_for_trial,
            has_accepted_terms_of_service,
+            account_too_young,
            accept_terms_of_service_in_progress: false,
            accept_terms_of_service_callback: Arc::new(|_, _| {}),
            sign_in_callback: Arc::new(|_, _| {}),

@@ -1264,33 +1317,188 @@ impl Component for ZedAiConfiguration {
                .p_4()
                .gap_4()
                .children(vec![
-                    single_example("Not connected", configuration(false, None, false, true)),
+                    single_example(
+                        "Not connected",
+                        configuration(false, None, false, false, true),
+                    ),
                    single_example(
                        "Accept Terms of Service",
-                        configuration(true, None, true, false),
+                        configuration(true, None, true, false, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
-                        configuration(true, None, false, true),
+                        configuration(true, None, false, false, true),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
-                        configuration(true, None, true, true),
+                        configuration(true, None, true, false, true),
                    ),
                    single_example(
                        "Free Plan",
-                        configuration(true, Some(proto::Plan::Free), true, true),
+                        configuration(true, Some(proto::Plan::Free), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
-                        configuration(true, Some(proto::Plan::ZedProTrial), true, true),
+                        configuration(true, Some(proto::Plan::ZedProTrial), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Plan",
-                        configuration(true, Some(proto::Plan::ZedPro), true, true),
+                        configuration(true, Some(proto::Plan::ZedPro), true, false, true),
                    ),
                ])
                .into_any_element(),
        )
    }
}

+#[cfg(test)]
+mod tests {
+    use super::*;
+    use http_client::http::{HeaderMap, StatusCode};
+    use language_model::LanguageModelCompletionError;
+
+    #[test]
+    fn test_api_error_conversion_with_upstream_http_error() {
+        // upstream_http_error with 503 status should become ServerOverloaded
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 503, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_error with 500 status should become ApiInternalServerError
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the OpenAI API: internal server error"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 500, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_error with 429 status should become RateLimitExceeded
+        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
+                assert_eq!(
+                    message,
+                    "Received an error from the Google API: rate limit exceeded"
+                );
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream 429, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
+        let error_body = "Regular internal server error";
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
+                assert_eq!(provider, PROVIDER_NAME);
+                assert_eq!(message, "Regular internal server error");
+            }
+            _ => panic!(
+                "Expected ApiInternalServerError for regular 500, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // upstream_http_429 format should be converted to UpstreamProviderError
+        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::UpstreamProviderError {
+                message,
+                status,
+                retry_after,
+            } => {
+                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
+                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
+                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
+            }
+            _ => panic!(
+                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
+                completion_error
+            ),
+        }
+
+        // Invalid JSON in error body should fall back to regular error handling
+        let error_body = "Not JSON at all";
+
+        let api_error = ApiError {
+            status: StatusCode::INTERNAL_SERVER_ERROR,
+            body: error_body.to_string(),
+            headers: HeaderMap::new(),
+        };
+
+        let completion_error: LanguageModelCompletionError = api_error.into();
+
+        match completion_error {
+            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
+                assert_eq!(provider, PROVIDER_NAME);
+            }
+            _ => panic!(
+                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
+                completion_error
+            ),
+        }
+    }
+}

@@ -94,6 +94,7 @@ pub struct State {
    _subscription: Subscription,
}

+const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";

impl State {

@@ -151,6 +152,8 @@ impl State {
        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
                (api_key, true)
+            } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
+                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, &cx)

@@ -559,11 +562,11 @@ pub fn into_google(
        stop_sequences: Some(request.stop),
        max_output_tokens: None,
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
-        thinking_config: match mode {
-            GoogleModelMode::Thinking { budget_tokens } => {
+        thinking_config: match (request.thinking_allowed, mode) {
+            (true, GoogleModelMode::Thinking { budget_tokens }) => {
                budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
            }
-            GoogleModelMode::Default => None,
+            _ => None,
        },
        top_p: None,
        top_k: None,

@@ -903,7 +906,7 @@ impl Render for ConfigurationView {
            )
            .child(
                Label::new(
-                    format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
+                    format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
                )
                .size(LabelSize::Small).color(Color::Muted),
            )

@@ -922,7 +925,7 @@ impl Render for ConfigurationView {
                .gap_1()
                .child(Icon::new(IconName::Check).color(Color::Success))
                .child(Label::new(if env_var_set {
-                    format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
+                    format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
                } else {
                    "API key configured.".to_string()
                })),

@@ -935,7 +938,7 @@ impl Render for ConfigurationView {
                    .icon_position(IconPosition::Start)
                    .disabled(env_var_set)
                    .when(env_var_set, |this| {
-                        this.tooltip(Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable.")))
+                        this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
                    })
                    .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )

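Authentication now tries `GOOGLE_AI_API_KEY` first, then the newly supported `GEMINI_API_KEY`, and only then the stored credential. A synchronous sketch of the fallback chain (the real credentials-store lookup is async):

```rust
use std::env;

/// Returns (api_key, from_env). `stored` stands in for the credentials-store
/// lookup that the real code awaits.
fn resolve_api_key(stored: Option<String>) -> Option<(String, bool)> {
    env::var("GOOGLE_AI_API_KEY")
        .or_else(|_| env::var("GEMINI_API_KEY"))
        .ok()
        .map(|key| (key, true))
        .or_else(|| stored.map(|key| (key, false)))
}

fn main() {
    // With neither variable set, only the stored credential can win.
    println!("{:?}", resolve_api_key(Some("stored-key".into())));
}
```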
@@ -410,8 +410,20 @@ pub fn into_mistral(
                        .push_part(mistral::MessagePart::Text { text: text.clone() });
                }
                MessageContent::RedactedThinking(_) => {}
-                MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => {
-                    // Tool content is not supported in User messages for Mistral
+                MessageContent::ToolUse(_) => {
+                    // Tool use is not supported in User messages for Mistral
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    let tool_content = match &tool_result.content {
+                        LanguageModelToolResultContent::Text(text) => text.to_string(),
+                        LanguageModelToolResultContent::Image(_) => {
+                            "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+                        }
+                    };
+                    messages.push(mistral::RequestMessage::Tool {
+                        content: tool_content,
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
                }
            }
        }

@@ -482,24 +494,6 @@ pub fn into_mistral(
        }
    }

-    for message in &request.messages {
-        for content in &message.content {
-            if let MessageContent::ToolResult(tool_result) = content {
-                let content = match &tool_result.content {
-                    LanguageModelToolResultContent::Text(text) => text.to_string(),
-                    LanguageModelToolResultContent::Image(_) => {
-                        "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
-                    }
-                };
-
-                messages.push(mistral::RequestMessage::Tool {
-                    content,
-                    tool_call_id: tool_result.tool_use_id.to_string(),
-                });
-            }
-        }
-    }
-
    // The Mistral API requires that tool messages be followed by assistant messages,
    // not user messages. When we have a tool->user sequence in the conversation,
    // we need to insert a placeholder assistant message to maintain proper conversation

@@ -911,6 +905,7 @@ mod tests {
            intent: None,
            mode: None,
            stop: vec![],
+            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);

@@ -943,6 +938,7 @@ mod tests {
            intent: None,
            mode: None,
            stop: vec![],
+            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);

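The Mistral conversion now emits a `Tool` message inline while walking the conversation instead of collecting tool results in the removed second pass, which keeps tool results in conversation order. A minimal sketch of the per-result mapping (the enums are illustrative stand-ins for Zed's types):

```rust
enum ToolResultContent {
    Text(String),
    Image,
}

struct ToolResult {
    tool_use_id: String,
    content: ToolResultContent,
}

enum MistralMessage {
    Tool { content: String, tool_call_id: String },
}

/// Map one tool result to a Mistral `Tool` role message, substituting a
/// textual note for images, which this path cannot carry.
fn tool_result_to_message(result: &ToolResult) -> MistralMessage {
    let content = match &result.content {
        ToolResultContent::Text(text) => text.clone(),
        ToolResultContent::Image => {
            "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]"
                .to_string()
        }
    };
    MistralMessage::Tool {
        content,
        tool_call_id: result.tool_use_id.clone(),
    }
}

fn main() {
    let result = ToolResult {
        tool_use_id: "call_0".to_string(),
        content: ToolResultContent::Text("42".to_string()),
    };
    let MistralMessage::Tool { content, tool_call_id } = tool_result_to_message(&result);
    assert_eq!((content.as_str(), tool_call_id.as_str()), ("42", "call_0"));
}
```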
@@ -415,7 +415,10 @@ impl OllamaLanguageModel {
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
-            think: self.model.supports_thinking,
+            think: self
+                .model
+                .supports_thinking
+                .map(|supports_thinking| supports_thinking && request.thinking_allowed),
            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
        }
    }

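Ollama's `supports_thinking` is an `Option<bool>` on the model record, so the new gate preserves `None` as "capability unknown" instead of forcing a default. A tiny sketch of the mapping:

```rust
/// None stays None (capability unknown); Some(x) is ANDed with the request flag.
fn think_flag(supports_thinking: Option<bool>, thinking_allowed: bool) -> Option<bool> {
    supports_thinking.map(|supports| supports && thinking_allowed)
}

fn main() {
    assert_eq!(think_flag(None, true), None);
    assert_eq!(think_flag(Some(true), false), Some(false));
    assert_eq!(think_flag(Some(true), true), Some(true));
}
```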
@@ -2,7 +2,6 @@ use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;

-use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};

@@ -18,7 +17,7 @@ use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsStore, update_settings_file};
+use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;

@@ -28,7 +27,6 @@ use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

-use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;

@@ -621,26 +619,32 @@ struct RawToolCall {
    arguments: String,
}

+pub(crate) fn collect_tiktoken_messages(
+    request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+    request
+        .messages
+        .into_iter()
+        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+            role: match message.role {
+                Role::User => "user".into(),
+                Role::Assistant => "assistant".into(),
+                Role::System => "system".into(),
+            },
+            content: Some(message.string_contents()),
+            name: None,
+            function_call: None,
+        })
+        .collect::<Vec<_>>()
+}
+
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
-        let messages = request
-            .messages
-            .into_iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(message.string_contents()),
-                name: None,
-                function_call: None,
-            })
-            .collect::<Vec<_>>();
+        let messages = collect_tiktoken_messages(request);

        match model {
            Model::Custom { max_tokens, .. } => {

@@ -678,7 +682,6 @@ pub fn count_open_ai_tokens(

struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
-    api_url_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

@@ -691,23 +694,6 @@ impl ConfigurationView {
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

-        let api_url = AllLanguageModelSettings::get_global(cx)
-            .openai
-            .api_url
-            .clone();
-
-        let api_url_editor = cx.new(|cx| {
-            let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");
-
-            if !api_url.is_empty() {
-                input.editor.update(cx, |editor, cx| {
-                    editor.set_text(&*api_url, window, cx);
-                });
-            }
-            input
-        });
-
        cx.observe(&state, |_, _, cx| {

@@ -735,7 +721,6 @@ impl ConfigurationView {

        Self {
            api_key_editor,
-            api_url_editor,
            state,
            load_credentials_task,
        }

@@ -783,57 +768,6 @@ impl ConfigurationView {
        cx.notify();
    }

-    fn save_api_url(&mut self, cx: &mut Context<Self>) {
-        let api_url = self
-            .api_url_editor
-            .read(cx)
-            .editor()
-            .read(cx)
-            .text(cx)
-            .trim()
-            .to_string();
-
-        let current_url = AllLanguageModelSettings::get_global(cx)
-            .openai
-            .api_url
-            .clone();
-
-        let effective_current_url = if current_url.is_empty() {
-            open_ai::OPEN_AI_API_URL
-        } else {
-            &current_url
-        };
-
-        if !api_url.is_empty() && api_url != effective_current_url {
-            let fs = <dyn Fs>::global(cx);
-            update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
-                if let Some(settings) = settings.openai.as_mut() {
-                    settings.api_url = Some(api_url.clone());
-                } else {
-                    settings.openai = Some(OpenAiSettingsContent {
-                        api_url: Some(api_url.clone()),
-                        available_models: None,
-                    });
-                }
-            });
-        }
-    }
-
-    fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        self.api_url_editor.update(cx, |input, cx| {
-            input.editor.update(cx, |editor, cx| {
-                editor.set_text("", window, cx);
-            });
-        });
-        let fs = <dyn Fs>::global(cx);
-        update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
-            if let Some(settings) = settings.openai.as_mut() {
-                settings.api_url = None;
-            }
-        });
-        cx.notify();
-    }
-
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }

@@ -846,7 +780,6 @@ impl Render for ConfigurationView {
        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()

@@ -910,59 +843,34 @@ impl Render for ConfigurationView {
                .into_any()
        };

-        let custom_api_url_set =
-            AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;
-
-        let api_url_section = if custom_api_url_set {
-            h_flex()
-                .mt_1()
-                .p_1()
-                .justify_between()
-                .rounded_md()
-                .border_1()
-                .border_color(cx.theme().colors().border)
-                .bg(cx.theme().colors().background)
-                .child(
-                    h_flex()
-                        .gap_1()
-                        .child(Icon::new(IconName::Check).color(Color::Success))
-                        .child(Label::new("Custom API URL configured.")),
-                )
-                .child(
-                    Button::new("reset-api-url", "Reset API URL")
-                        .label_size(LabelSize::Small)
-                        .icon(IconName::Undo)
-                        .icon_size(IconSize::Small)
-                        .icon_position(IconPosition::Start)
-                        .layer(ElevationIndex::ModalSurface)
-                        .on_click(
-                            cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
-                        ),
-                )
-                .into_any()
-        } else {
-            v_flex()
-                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
-                    this.save_api_url(cx);
-                    cx.notify();
-                }))
-                .mt_2()
-                .pt_2()
-                .border_t_1()
-                .border_color(cx.theme().colors().border_variant)
-                .gap_1()
-                .child(
-                    List::new()
-                        .child(InstructionListItem::text_only(
-                            "Optionally, you can change the base URL for the OpenAI API request.",
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Paste the new API endpoint below and hit enter",
-                        )),
-                )
-                .child(self.api_url_editor.clone())
-                .into_any()
-        };
+        let compatible_api_section = h_flex()
+            .mt_1p5()
+            .gap_0p5()
+            .flex_wrap()
+            .when(self.should_render_editor(cx), |this| {
+                this.pt_1p5()
+                    .border_t_1()
+                    .border_color(cx.theme().colors().border_variant)
+            })
+            .child(
+                h_flex()
+                    .gap_2()
+                    .child(
+                        Icon::new(IconName::Info)
+                            .size(IconSize::XSmall)
+                            .color(Color::Muted),
+                    )
+                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
+            )
+            .child(
+                Button::new("docs", "Learn More")
+                    .icon(IconName::ArrowUpRight)
+                    .icon_size(IconSize::XSmall)
+                    .icon_color(Color::Muted)
+                    .on_click(move |_, _window, cx| {
+                        cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible")
+                    }),
+            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()

@@ -970,7 +878,7 @@ impl Render for ConfigurationView {
            v_flex()
                .size_full()
                .child(api_key_section)
-                .child(api_url_section)
+                .child(compatible_api_section)
                .into_any()
        }
    }

@@ -999,6 +907,7 @@ mod tests {
            tool_choice: None,
            stop: vec![],
            temperature: None,
+            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs

crates/language_models/src/provider/open_ai_compatible.rs (new file, 522 lines)

@@ -0,0 +1,522 @@
use anyhow::{Context as _, Result, anyhow};
use credentials_provider::CredentialsProvider;

use convert_case::{Case, Casing};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, RateLimiter,
};
use menu;
use open_ai::{ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;

use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};

#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}

pub struct OpenAiCompatibleLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    id: Arc<str>,
    env_var_name: Arc<str>,
    api_key: Option<String>,
    api_key_from_env: bool,
    settings: OpenAiCompatibleSettings,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = self.settings.api_url.clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = self.settings.api_url.clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let env_var_name = self.env_var_name.clone();
        let api_url = self.settings.api_url.clone();
        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, &cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
                    false,
                )
            };
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}

impl OpenAiCompatibleLanguageModelProvider {
    pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
            AllLanguageModelSettings::get_global(cx)
                .openai_compatible
                .get(id)
        }

        let state = cx.new(|cx| State {
            id: id.clone(),
            env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
            settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let Some(settings) = resolve_settings(&this.id, cx) else {
                    return;
                };
                if &this.settings != settings {
                    this.settings = settings.clone();
                    cx.notify();
                }
            }),
        });

        Self {
            id: id.clone().into(),
            name: id.into(),
            http_client,
            state,
        }
    }
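Provider ids here are arbitrary user-chosen settings keys, so the environment variable name is derived with `convert_case`'s `Case::Constant` (SCREAMING_SNAKE_CASE). A sketch of the derivation, assuming the convert_case crate:

```rust
use convert_case::{Case, Casing};

/// "together-ai" -> "TOGETHER_AI_API_KEY"
fn env_var_name(provider_id: &str) -> String {
    format!("{}_API_KEY", provider_id).to_case(Case::Constant)
}

fn main() {
    assert_eq!(env_var_name("together-ai"), "TOGETHER_AI_API_KEY");
}
```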
    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(OpenAiCompatibleLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            provider_id: self.id.clone(),
            provider_name: self.name.clone(),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelProviderName {
        self.name.clone()
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAiCompat
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .settings
            .available_models
            .first()
            .map(|model| self.create_language_model(model.clone()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        None
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .settings
            .available_models
            .iter()
            .map(|model| self.create_language_model(model.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

pub struct OpenAiCompatibleLanguageModel {
    id: LanguageModelId,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    model: AvailableModel,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiCompatibleLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
            (state.api_key.clone(), state.settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiCompatibleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(
            self.model
                .display_name
                .clone()
                .unwrap_or_else(|| self.model.name.clone()),
        )
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.name)
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_tokens
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let max_token_count = self.max_token_count();
        cx.background_spawn(async move {
            let messages = super::open_ai::collect_tiktoken_messages(request);
            let model = if max_token_count >= 100_000 {
                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                "gpt-4o"
            } else {
                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
                // supported with this tiktoken method
                "gpt-4"
            };
            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = OpenAiEventMapper::new();
|
||||
Ok(mapper.map_stream(completions.await?).boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_editor = cx.new(|cx| {
|
||||
SingleLineInput::new(
|
||||
window,
|
||||
cx,
|
||||
"000000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let load_credentials_task = Some(cx.spawn_in(window, {
|
||||
let state = state.clone();
|
||||
async move |this, cx| {
|
||||
if let Some(task) = state
|
||||
.update(cx, |state, cx| state.authenticate(cx))
|
||||
.log_err()
|
||||
{
|
||||
// We don't log an error, because "not signed in" is also an error.
|
||||
let _ = task.await;
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.load_credentials_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
|
||||
Self {
|
||||
api_key_editor,
|
||||
state,
|
||||
load_credentials_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
return;
|
||||
}
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_name = self.state.read(cx).env_var_name.clone();
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key."))
|
||||
.child(
|
||||
div()
|
||||
.pt(DynamicSpacing::Base04.rems(cx))
|
||||
.child(self.api_key_editor.clone())
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {env_var_name} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {env_var_name} environment variable.")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-key", "Reset API Key")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
.into_any()
|
||||
};
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials…")).into_any()
|
||||
} else {
|
||||
v_flex().size_full().child(api_key_section).into_any()
|
||||
}
|
||||
}
|
||||
}
|
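Worth noting from the provider implementation above: for an OpenAI-compatible provider the model list comes entirely from user settings, `default_model` simply returns the first configured entry, and `default_fast_model` is always `None`. A minimal sketch of that contract (illustrative only, assuming a configured provider and a gpui `cx` in scope; not part of the commit):

    let models = provider.provided_models(cx);
    let default = provider.default_model(cx);
    // With at least one configured model, the default is the first list entry.
    assert_eq!(
        default.map(|m| m.id()),
        models.first().map(|m| m.id()),
    );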
@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel {

     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
         let model_id = self.model.id().trim().to_lowercase();
-        if model_id.contains("gemini") {
+        if model_id.contains("gemini") || model_id.contains("grok-4") {
             LanguageModelToolSchemaFormat::JsonSchemaSubset
         } else {
             LanguageModelToolSchemaFormat::JsonSchema
@ -523,7 +523,9 @@ pub fn into_open_router(
             None
         },
         usage: open_router::RequestUsage { include: true },
-        reasoning: if let OpenRouterModelMode::Thinking { budget_tokens } = model.mode {
+        reasoning: if request.thinking_allowed
+            && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
+        {
             Some(open_router::Reasoning {
                 effort: None,
                 max_tokens: budget_tokens,
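Two OpenRouter behavior changes land above: grok-4 now joins Gemini in using the JSON-schema-subset format for tool inputs, and a reasoning block is only attached when the request allows thinking. The new condition chains `&&` with `let` (a let-chain, which needs a recent Rust toolchain); an equivalent formulation without let-chains would be a guarded match (sketch only, trailing `Reasoning` fields elided as in the hunk):

    reasoning: match model.mode {
        OpenRouterModelMode::Thinking { budget_tokens } if request.thinking_allowed => {
            Some(open_router::Reasoning {
                effort: None,
                max_tokens: budget_tokens,
                // …
            })
        }
        _ => None,
    },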

crates/language_models/src/provider/x_ai.rs (new file, 571 lines)
@ -0,0 +1,571 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use x_ai::Model;

use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

use crate::{AllLanguageModelSettings, ui::InstructionListItem};

const PROVIDER_ID: &str = "x_ai";
const PROVIDER_NAME: &str = "xAI";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct XAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}

pub struct XAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

const XAI_API_KEY_VAR: &str = "XAI_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
        let api_url = if settings.api_url.is_empty() {
            x_ai::XAI_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
        let api_url = if settings.api_url.is_empty() {
            x_ai::XAI_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
        let api_url = if settings.api_url.is_empty() {
            x_ai::XAI_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, &cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key)
                        .context(format!("invalid {PROVIDER_NAME} API key"))?,
                    false,
                )
            };
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}

impl XAiLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: x_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(XAiLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for XAiLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for XAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiXAi
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(x_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(x_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        for model in x_ai::Model::iter() {
            if !matches!(model, x_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        for model in &AllLanguageModelSettings::get_global(cx)
            .x_ai
            .available_models
        {
            models.insert(
                model.name.clone(),
                x_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}
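`provided_models` above collects the built-in catalog and any user-configured entries into a `BTreeMap` keyed by model id, so a settings entry whose name matches a built-in id replaces it, and the final list comes out sorted by id regardless of insertion order. The underlying map property, as a tiny standalone sketch (placeholder keys, not from the commit):

    use std::collections::BTreeMap;

    let mut models = BTreeMap::new();
    models.insert("grok-3", "built-in");
    models.insert("grok-3", "from settings"); // same key: the later insert wins
    assert_eq!(models.get("grok-3"), Some(&"from settings"));
    // Iteration visits keys in sorted order, independent of insertion order.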
pub struct XAiLanguageModel {
    id: LanguageModelId,
    model: x_ai::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl XAiLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
            let api_url = if settings.api_url.is_empty() {
                x_ai::XAI_API_URL.to_string()
            } else {
                settings.api_url.clone()
            };
            (state.api_key.clone(), api_url)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let api_key = api_key.context("Missing xAI API Key")?;
            let request =
                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for XAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tool()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        let model_id = self.model.id().trim().to_lowercase();
        if model_id.eq(x_ai::Model::Grok4.id()) {
            LanguageModelToolSchemaFormat::JsonSchemaSubset
        } else {
            LanguageModelToolSchemaFormat::JsonSchema
        }
    }

    fn telemetry_id(&self) -> String {
        format!("x_ai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_xai_tokens(request, self.model.clone(), cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = crate::provider::open_ai::into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.max_output_tokens(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

pub fn count_xai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        let model_name = if model.max_token_count() >= 100_000 {
            "gpt-4o"
        } else {
            "gpt-4"
        };
        tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}
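Since no xAI tokenizer is wired up here, `count_xai_tokens` estimates with tiktoken, mirroring the heuristic used for OpenAI-compatible models earlier in this diff: a context window of 100k tokens or more is assumed to mean the o200k_base vocabulary ("gpt-4o"), anything smaller falls back to "gpt-4" (cl100k_base). The results are budgeting estimates, not exact xAI token counts. Calling it is straightforward (sketch, assuming a `request` and gpui `cx` in scope):

    // Resolves to Result<u64> on a background thread.
    let token_count = count_xai_tokens(request, x_ai::Model::Grok4, cx);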
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "xai-0000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("xAI console"),
                            Some("https://console.x.ai/team/default/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the agent",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new("Note that xAI is a custom OpenAI-compatible provider.")
                        .size(LabelSize::Small)
                        .color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {XAI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}
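Taken as a whole, x_ai.rs is a thin wrapper over the existing OpenAI plumbing: only authentication, the model catalog, and the configuration UI are xAI-specific, which matches the "Note that xAI is a custom OpenAI-compatible provider." label rendered above. The request path, summarized from the file (names as they appear there):

    // LanguageModelRequest
    //   → crate::provider::open_ai::into_open_ai(...)
    //   → open_ai::stream_completion(http_client, api_url, api_key, request)
    //   → OpenAiEventMapper::map_stream(...)
    //   → stream of LanguageModelCompletionEvent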
@ -1,4 +1,7 @@
+use std::sync::Arc;
+
 use anyhow::Result;
+use collections::HashMap;
 use gpui::App;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

@ -15,12 +18,14 @@ use crate::provider::{
     mistral::MistralSettings,
     ollama::OllamaSettings,
     open_ai::OpenAiSettings,
+    open_ai_compatible::OpenAiCompatibleSettings,
     open_router::OpenRouterSettings,
     vercel::VercelSettings,
+    x_ai::XAiSettings,
 };

 /// Initializes the language model settings.
-pub fn init(cx: &mut App) {
+pub fn init_settings(cx: &mut App) {
     AllLanguageModelSettings::register(cx);
 }

@ -28,33 +33,35 @@ pub fn init(cx: &mut App) {
 pub struct AllLanguageModelSettings {
     pub anthropic: AnthropicSettings,
     pub bedrock: AmazonBedrockSettings,
-    pub ollama: OllamaSettings,
-    pub openai: OpenAiSettings,
-    pub open_router: OpenRouterSettings,
-    pub zed_dot_dev: ZedDotDevSettings,
-    pub google: GoogleSettings,
-    pub vercel: VercelSettings,
-
-    pub lmstudio: LmStudioSettings,
     pub deepseek: DeepSeekSettings,
+    pub google: GoogleSettings,
+    pub lmstudio: LmStudioSettings,
     pub mistral: MistralSettings,
+    pub ollama: OllamaSettings,
+    pub open_router: OpenRouterSettings,
+    pub openai: OpenAiSettings,
+    pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>,
+    pub vercel: VercelSettings,
+    pub x_ai: XAiSettings,
+    pub zed_dot_dev: ZedDotDevSettings,
 }

 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct AllLanguageModelSettingsContent {
     pub anthropic: Option<AnthropicSettingsContent>,
     pub bedrock: Option<AmazonBedrockSettingsContent>,
-    pub ollama: Option<OllamaSettingsContent>,
+    pub deepseek: Option<DeepseekSettingsContent>,
+    pub google: Option<GoogleSettingsContent>,
     pub lmstudio: Option<LmStudioSettingsContent>,
-    pub openai: Option<OpenAiSettingsContent>,
+    pub mistral: Option<MistralSettingsContent>,
+    pub ollama: Option<OllamaSettingsContent>,
     pub open_router: Option<OpenRouterSettingsContent>,
+    pub openai: Option<OpenAiSettingsContent>,
+    pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>,
+    pub vercel: Option<VercelSettingsContent>,
+    pub x_ai: Option<XAiSettingsContent>,
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
-    pub google: Option<GoogleSettingsContent>,
-    pub deepseek: Option<DeepseekSettingsContent>,
-    pub vercel: Option<VercelSettingsContent>,
-
-    pub mistral: Option<MistralSettingsContent>,
 }

 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]

@ -102,6 +109,12 @@ pub struct OpenAiSettingsContent {
     pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
 }

+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OpenAiCompatibleSettingsContent {
+    pub api_url: String,
+    pub available_models: Vec<provider::open_ai_compatible::AvailableModel>,
+}
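Note that, unlike the other providers' content structs here, both fields of `OpenAiCompatibleSettingsContent` are required rather than `Option`al: a compatible provider is unusable without an endpoint and a model list. A hedged sketch of building one entry of the new `openai_compatible` map in code (provider name and URL are illustrative, not from this commit):

    let mut openai_compatible = HashMap::default();
    openai_compatible.insert(
        Arc::<str>::from("example-provider"),
        OpenAiCompatibleSettingsContent {
            api_url: "https://api.example.com/v1".to_string(),
            available_models: Vec::new(),
        },
    );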

 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct VercelSettingsContent {
     pub api_url: Option<String>,

@ -114,6 +127,12 @@ pub struct GoogleSettingsContent {
     pub available_models: Option<Vec<provider::google::AvailableModel>>,
 }

+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct XAiSettingsContent {
+    pub api_url: Option<String>,
+    pub available_models: Option<Vec<provider::x_ai::AvailableModel>>,
+}
+
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct ZedDotDevSettingsContent {
     available_models: Option<Vec<cloud::AvailableModel>>,

@ -219,6 +238,19 @@ impl settings::Settings for AllLanguageModelSettings {
             openai.as_ref().and_then(|s| s.available_models.clone()),
         );

+        // OpenAI Compatible
+        if let Some(openai_compatible) = value.openai_compatible.clone() {
+            for (id, openai_compatible_settings) in openai_compatible {
+                settings.openai_compatible.insert(
+                    id,
+                    OpenAiCompatibleSettings {
+                        api_url: openai_compatible_settings.api_url,
+                        available_models: openai_compatible_settings.available_models,
+                    },
+                );
+            }
+        }
+
         // Vercel
         let vercel = value.vercel.clone();
         merge(

@ -230,6 +262,18 @@ impl settings::Settings for AllLanguageModelSettings {
             vercel.as_ref().and_then(|s| s.available_models.clone()),
         );

+        // XAI
+        let x_ai = value.x_ai.clone();
+        merge(
+            &mut settings.x_ai.api_url,
+            x_ai.as_ref().and_then(|s| s.api_url.clone()),
+        );
+        merge(
+            &mut settings.x_ai.available_models,
+            x_ai.as_ref().and_then(|s| s.available_models.clone()),
+        );
+
         // ZedDotDev
         merge(
             &mut settings.zed_dot_dev.available_models,
             value