diff --git a/Cargo.lock b/Cargo.lock index e2d86576c3..15a28016c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9094,6 +9094,7 @@ dependencies = [ "util", "vercel", "workspace-hack", + "x_ai", "zed_llm_client", ] @@ -19840,6 +19841,17 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" +[[package]] +name = "x_ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "schemars", + "serde", + "strum 0.27.1", + "workspace-hack", +] + [[package]] name = "xattr" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index 0e4cd1504f..afb47c006e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -179,6 +179,7 @@ members = [ "crates/welcome", "crates/workspace", "crates/worktree", + "crates/x_ai", "crates/zed", "crates/zed_actions", "crates/zeta", @@ -394,6 +395,7 @@ web_search_providers = { path = "crates/web_search_providers" } welcome = { path = "crates/welcome" } workspace = { path = "crates/workspace" } worktree = { path = "crates/worktree" } +x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zeta = { path = "crates/zeta" } diff --git a/assets/icons/ai_x_ai.svg b/assets/icons/ai_x_ai.svg new file mode 100644 index 0000000000..289525c8ef --- /dev/null +++ b/assets/icons/ai_x_ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 3c24ee59f6..b2ec768435 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -21,6 +21,7 @@ pub enum IconName { AiOpenAi, AiOpenRouter, AiVZero, + AiXAi, AiZed, ArrowCircle, ArrowDown, diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 0f248edd57..5d158e84f4 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -43,6 +43,7 @@ ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } open_router = { workspace = true, features = ["schemars"] } vercel = { workspace = true, features = ["schemars"] } +x_ai = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true proto.workspace = true release_channel.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index c7324732c9..192f5a5fae 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -20,6 +20,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; use crate::provider::open_router::OpenRouterLanguageModelProvider; use crate::provider::vercel::VercelLanguageModelProvider; +use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity, client: Arc, cx: &mut App) { @@ -81,5 +82,6 @@ fn register_language_model_providers( VercelLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx); registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6bc93bd366..c717be7c90 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -10,3 +10,4 @@ pub mod ollama; pub mod open_ai; pub mod open_router; pub mod vercel; +pub mod x_ai; diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index c46135ff3e..5a6acc4329 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -376,7 +376,7 @@ impl LanguageModel for OpenRouterLanguageModel { fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); - if model_id.contains("gemini") { + if model_id.contains("gemini") || model_id.contains("grok-4") { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs new file mode 100644 index 0000000000..5f6034571b --- /dev/null +++ b/crates/language_models/src/provider/x_ai.rs @@ -0,0 +1,571 @@ +use anyhow::{Context as _, Result, anyhow}; +use collections::BTreeMap; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, StreamExt, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, +}; +use menu; +use open_ai::ResponseStreamEvent; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; +use strum::IntoEnumIterator; +use x_ai::Model; + +use ui::{ElevationIndex, List, Tooltip, prelude::*}; +use ui_input::SingleLineInput; +use util::ResultExt; + +use crate::{AllLanguageModelSettings, ui::InstructionListItem}; + +const PROVIDER_ID: &str = "x_ai"; +const PROVIDER_NAME: &str = "xAI"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct XAiSettings { + pub api_url: String, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: u64, + pub max_output_tokens: Option, + pub max_completion_tokens: Option, +} + +pub struct XAiLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + _subscription: Subscription, +} + +const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = ::global(cx); + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } +} + +impl XAiLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_this: &mut State, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: x_ai::Model) -> Arc { + Arc::new(XAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for XAiLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for XAiLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiXAi + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(x_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(x_ai::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + for model in x_ai::Model::iter() { + if !matches!(model, x_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + for model in &AllLanguageModelSettings::get_global(cx) + .x_ai + .available_models + { + models.insert( + model.name.clone(), + x_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + max_completion_tokens: model.max_completion_tokens, + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct XAiLanguageModel { + id: LanguageModelId, + model: x_ai::Model, + state: gpui::Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl XAiLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).x_ai; + let api_url = if settings.api_url.is_empty() { + x_ai::XAI_API_URL.to_string() + } else { + settings.api_url.clone() + }; + (state.api_key.clone(), api_url) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.context("Missing xAI API Key")?; + let request = + open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for XAiLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool() + } + + fn supports_images(&self) -> bool { + self.model.supports_images() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + let model_id = self.model.id().trim().to_lowercase(); + if model_id.eq(x_ai::Model::Grok4.id()) { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } else { + LanguageModelToolSchemaFormat::JsonSchema + } + } + + fn telemetry_id(&self) -> String { + format!("x_ai/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_xai_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let request = crate::provider::open_ai::into_open_ai( + request, + self.model.id(), + self.model.supports_parallel_tool_calls(), + self.max_output_tokens(), + ); + let completions = self.stream_completion(request, cx); + async move { + let mapper = crate::provider::open_ai::OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +pub fn count_xai_tokens( + request: LanguageModelRequest, + model: Model, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + let model_name = if model.max_token_count() >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model_name, &messages).map(|tokens| tokens as u64) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity, + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + SingleLineInput::new( + window, + cx, + "xai-0000000000000000000000000000000000000000000000000", + ) + .label("API key") + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self + .api_key_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + // Don't proceed if no API key is provided and we're not authenticated + if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text("", window, cx); + }); + }); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + + let api_key_section = if self.should_render_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("xAI console"), + Some("https://console.x.ai/team/default/api-keys"), + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the agent", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new(format!( + "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new("Note that xAI is a custom OpenAI-compatible provider.") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {XAI_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + }; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials…")).into_any() + } else { + v_flex().size_full().child(api_key_section).into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index f96a2c0a66..dafbb62910 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -17,6 +17,7 @@ use crate::provider::{ open_ai::OpenAiSettings, open_router::OpenRouterSettings, vercel::VercelSettings, + x_ai::XAiSettings, }; /// Initializes the language model settings. @@ -28,33 +29,33 @@ pub fn init(cx: &mut App) { pub struct AllLanguageModelSettings { pub anthropic: AnthropicSettings, pub bedrock: AmazonBedrockSettings, - pub ollama: OllamaSettings, - pub openai: OpenAiSettings, - pub open_router: OpenRouterSettings, - pub zed_dot_dev: ZedDotDevSettings, - pub google: GoogleSettings, - pub vercel: VercelSettings, - - pub lmstudio: LmStudioSettings, pub deepseek: DeepSeekSettings, + pub google: GoogleSettings, + pub lmstudio: LmStudioSettings, pub mistral: MistralSettings, + pub ollama: OllamaSettings, + pub open_router: OpenRouterSettings, + pub openai: OpenAiSettings, + pub vercel: VercelSettings, + pub x_ai: XAiSettings, + pub zed_dot_dev: ZedDotDevSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct AllLanguageModelSettingsContent { pub anthropic: Option, pub bedrock: Option, - pub ollama: Option, + pub deepseek: Option, + pub google: Option, pub lmstudio: Option, - pub openai: Option, + pub mistral: Option, + pub ollama: Option, pub open_router: Option, + pub openai: Option, + pub vercel: Option, + pub x_ai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, - pub google: Option, - pub deepseek: Option, - pub vercel: Option, - - pub mistral: Option, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -114,6 +115,12 @@ pub struct GoogleSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct XAiSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { available_models: Option>, @@ -230,6 +237,18 @@ impl settings::Settings for AllLanguageModelSettings { vercel.as_ref().and_then(|s| s.available_models.clone()), ); + // XAI + let x_ai = value.x_ai.clone(); + merge( + &mut settings.x_ai.api_url, + x_ai.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.x_ai.available_models, + x_ai.as_ref().and_then(|s| s.available_models.clone()), + ); + + // ZedDotDev merge( &mut settings.zed_dot_dev.available_models, value diff --git a/crates/x_ai/Cargo.toml b/crates/x_ai/Cargo.toml new file mode 100644 index 0000000000..7ca0ca0939 --- /dev/null +++ b/crates/x_ai/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "x_ai" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/x_ai.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +strum.workspace = true +workspace-hack.workspace = true diff --git a/crates/x_ai/LICENSE-GPL b/crates/x_ai/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/x_ai/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs new file mode 100644 index 0000000000..ac116b2f8f --- /dev/null +++ b/crates/x_ai/src/x_ai.rs @@ -0,0 +1,126 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use strum::EnumIter; + +pub const XAI_API_URL: &str = "https://api.x.ai/v1"; + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[serde(rename = "grok-2-vision-latest")] + Grok2Vision, + #[default] + #[serde(rename = "grok-3-latest")] + Grok3, + #[serde(rename = "grok-3-mini-latest")] + Grok3Mini, + #[serde(rename = "grok-3-fast-latest")] + Grok3Fast, + #[serde(rename = "grok-3-mini-fast-latest")] + Grok3MiniFast, + #[serde(rename = "grok-4-latest")] + Grok4, + #[serde(rename = "custom")] + Custom { + name: String, + /// The name displayed in the UI, such as in the assistant panel model dropdown menu. + display_name: Option, + max_tokens: u64, + max_output_tokens: Option, + max_completion_tokens: Option, + }, +} + +impl Model { + pub fn default_fast() -> Self { + Self::Grok3Fast + } + + pub fn from_id(id: &str) -> Result { + match id { + "grok-2-vision" => Ok(Self::Grok2Vision), + "grok-3" => Ok(Self::Grok3), + "grok-3-mini" => Ok(Self::Grok3Mini), + "grok-3-fast" => Ok(Self::Grok3Fast), + "grok-3-mini-fast" => Ok(Self::Grok3MiniFast), + _ => anyhow::bail!("invalid model id '{id}'"), + } + } + + pub fn id(&self) -> &str { + match self { + Self::Grok2Vision => "grok-2-vision", + Self::Grok3 => "grok-3", + Self::Grok3Mini => "grok-3-mini", + Self::Grok3Fast => "grok-3-fast", + Self::Grok3MiniFast => "grok-3-mini-fast", + Self::Grok4 => "grok-4", + Self::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Grok2Vision => "Grok 2 Vision", + Self::Grok3 => "Grok 3", + Self::Grok3Mini => "Grok 3 Mini", + Self::Grok3Fast => "Grok 3 Fast", + Self::Grok3MiniFast => "Grok 3 Mini Fast", + Self::Grok4 => "Grok 4", + Self::Custom { + name, display_name, .. + } => display_name.as_ref().unwrap_or(name), + } + } + + pub fn max_token_count(&self) -> u64 { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => 131_072, + Self::Grok4 => 256_000, + Self::Grok2Vision => 8_192, + Self::Custom { max_tokens, .. } => *max_tokens, + } + } + + pub fn max_output_tokens(&self) -> Option { + match self { + Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => Some(8_192), + Self::Grok4 => Some(64_000), + Self::Grok2Vision => Some(4_096), + Self::Custom { + max_output_tokens, .. + } => *max_output_tokens, + } + } + + pub fn supports_parallel_tool_calls(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_tool(&self) -> bool { + match self { + Self::Grok2Vision + | Self::Grok3 + | Self::Grok3Mini + | Self::Grok3Fast + | Self::Grok3MiniFast + | Self::Grok4 => true, + Model::Custom { .. } => false, + } + } + + pub fn supports_images(&self) -> bool { + match self { + Self::Grok2Vision => true, + _ => false, + } + } +} diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md index ade1ae672f..56eb4ab76c 100644 --- a/docs/src/ai/configuration.md +++ b/docs/src/ai/configuration.md @@ -23,6 +23,8 @@ Here's an overview of the supported providers and tool call support: | [OpenAI](#openai) | ✅ | | [OpenAI API Compatible](#openai-api-compatible) | 🚫 | | [OpenRouter](#openrouter) | ✅ | +| [Vercel](#vercel-v0) | ✅ | +| [xAI](#xai) | ✅ | ## Use Your Own Keys {#use-your-own-keys} @@ -444,27 +446,30 @@ Custom models will be listed in the model dropdown in the Agent Panel. Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider. -You can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. -Here are a few model examples you can plug in by using this feature: +Zed supports using OpenAI compatible APIs by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models. -#### X.ai Grok +To configure a compatible API, you can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. For example, to connect to [Together AI](https://www.together.ai/): -Example configuration for using X.ai Grok with Zed: +1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys). +2. Add the following to your `settings.json`: ```json +{ "language_models": { "openai": { - "api_url": "https://api.x.ai/v1", + "api_url": "https://api.together.xyz/v1", + "api_key": "YOUR_TOGETHER_AI_API_KEY", "available_models": [ { - "name": "grok-beta", - "display_name": "X.ai Grok (Beta)", - "max_tokens": 131072 + "name": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "display_name": "Together Mixtral 8x7B", + "max_tokens": 32768, + "supports_tools": true } - ], - "version": "1" - }, + ] + } } +} ``` ### OpenRouter {#openrouter} @@ -525,7 +530,9 @@ You can find available models and their specifications on the [OpenRouter models Custom models will be listed in the model dropdown in the Agent Panel. -### Vercel v0 +### Vercel v0 {#vercel-v0} + +> ✅ Supports tool use [Vercel v0](https://vercel.com/docs/v0/api) is an expert model for generating full-stack apps, with framework-aware completions optimized for modern stacks like Next.js and Vercel. It supports text and image inputs and provides fast streaming responses. @@ -537,6 +544,49 @@ Once you have it, paste it directly into the Vercel provider section in the pane You should then find it as `v0-1.5-md` in the model dropdown in the Agent Panel. +### xAI {#xai} + +> ✅ Supports tool use + +Zed has first-class support for [xAI](https://x.ai/) models. You can use your own API key to access Grok models. + +1. [Create an API key in the xAI Console](https://console.x.ai/team/default/api-keys) +2. Open the settings view (`agent: open configuration`) and go to the **xAI** section +3. Enter your xAI API key + +The xAI API key will be saved in your keychain. Zed will also use the `XAI_API_KEY` environment variable if it's defined. + +> **Note:** While the xAI API is OpenAI-compatible, Zed has first-class support for it as a dedicated provider. For the best experience, we recommend using the dedicated `x_ai` provider configuration instead of the [OpenAI API Compatible](#openai-api-compatible) method. + +#### Custom Models {#xai-custom-models} + +The Zed agent comes pre-configured with common Grok models. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`: + +```json +{ + "language_models": { + "x_ai": { + "api_url": "https://api.x.ai/v1", + "available_models": [ + { + "name": "grok-1.5", + "display_name": "Grok 1.5", + "max_tokens": 131072, + "max_output_tokens": 8192 + }, + { + "name": "grok-1.5v", + "display_name": "Grok 1.5V (Vision)", + "max_tokens": 131072, + "max_output_tokens": 8192, + "supports_images": true + } + ] + } + } +} +``` + ## Advanced Configuration {#advanced-configuration} ### Custom Provider Endpoints {#custom-provider-endpoint}