diff --git a/Cargo.lock b/Cargo.lock
index ad6c40bcf2..c7297e6d59 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -231,6 +231,7 @@ dependencies = [
"jsonschema",
"language",
"language_model",
+ "language_models",
"languages",
"log",
"lsp",
@@ -269,6 +270,7 @@ dependencies = [
"time_format",
"tree-sitter-md",
"ui",
+ "ui_input",
"unindent",
"urlencoding",
"util",
@@ -9097,11 +9099,11 @@ dependencies = [
"client",
"collections",
"component",
+ "convert_case 0.8.0",
"copilot",
"credentials_provider",
"deepseek",
"editor",
- "fs",
"futures 0.3.31",
"google_ai",
"gpui",
diff --git a/assets/icons/ai_open_ai_compat.svg b/assets/icons/ai_open_ai_compat.svg
new file mode 100644
index 0000000000..f6557caac3
--- /dev/null
+++ b/assets/icons/ai_open_ai_compat.svg
@@ -0,0 +1,4 @@
+
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 358871650b..309afaccf5 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1712,6 +1712,7 @@
"openai": {
"api_url": "https://api.openai.com/v1"
},
+ "openai_compatible": {},
"open_router": {
"api_url": "https://openrouter.ai/api/v1"
},
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index f8a7827615..1b3b022ab2 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -5490,7 +5490,7 @@ fn main() {{
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
- let provider = Arc::new(FakeLanguageModelProvider);
+ let provider = Arc::new(FakeLanguageModelProvider::default());
let model = provider.test_model();
let model: Arc = Arc::new(model);
diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml
index e55ae86fb7..33042c0ebd 100644
--- a/crates/agent_ui/Cargo.toml
+++ b/crates/agent_ui/Cargo.toml
@@ -53,6 +53,7 @@ itertools.workspace = true
jsonschema.workspace = true
language.workspace = true
language_model.workspace = true
+language_models.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
@@ -87,6 +88,7 @@ theme.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
+ui_input.workspace = true
urlencoding.workspace = true
util.workspace = true
uuid.workspace = true
diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs
index bfed81f5b7..e27c318221 100644
--- a/crates/agent_ui/src/active_thread.rs
+++ b/crates/agent_ui/src/active_thread.rs
@@ -3895,7 +3895,7 @@ mod tests {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(
Some(ConfiguredModel {
- provider: Arc::new(FakeLanguageModelProvider),
+ provider: Arc::new(FakeLanguageModelProvider::default()),
model,
}),
cx,
@@ -3979,7 +3979,7 @@ mod tests {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(
Some(ConfiguredModel {
- provider: Arc::new(FakeLanguageModelProvider),
+ provider: Arc::new(FakeLanguageModelProvider::default()),
model: model.clone(),
}),
cx,
diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs
index b5ad6ba37c..334c5ee6dc 100644
--- a/crates/agent_ui/src/agent_configuration.rs
+++ b/crates/agent_ui/src/agent_configuration.rs
@@ -1,3 +1,4 @@
+mod add_llm_provider_modal;
mod configure_context_server_modal;
mod manage_profiles_modal;
mod tool_picker;
@@ -37,7 +38,10 @@ use zed_actions::ExtensionCategoryFilter;
pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
-use crate::AddContextServer;
+use crate::{
+ AddContextServer,
+ agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
+};
pub struct AgentConfiguration {
fs: Arc,
@@ -304,16 +308,55 @@ impl AgentConfiguration {
v_flex()
.child(
- v_flex()
+ h_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.pb_0()
.mb_2p5()
- .gap_0p5()
- .child(Headline::new("LLM Providers"))
+ .items_start()
+ .justify_between()
.child(
- Label::new("Add at least one provider to use AI-powered features.")
- .color(Color::Muted),
+ v_flex()
+ .gap_0p5()
+ .child(Headline::new("LLM Providers"))
+ .child(
+ Label::new("Add at least one provider to use AI-powered features.")
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ PopoverMenu::new("add-provider-popover")
+ .trigger(
+ Button::new("add-provider", "Add Provider")
+ .icon_position(IconPosition::Start)
+ .icon(IconName::Plus)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .label_size(LabelSize::Small),
+ )
+ .anchor(gpui::Corner::TopRight)
+ .menu({
+ let workspace = self.workspace.clone();
+ move |window, cx| {
+ Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
+ menu.header("Compatible APIs").entry("OpenAI", None, {
+ let workspace = workspace.clone();
+ move |window, cx| {
+ workspace
+ .update(cx, |workspace, cx| {
+ AddLlmProviderModal::toggle(
+ LlmCompatibleProvider::OpenAi,
+ workspace,
+ window,
+ cx,
+ );
+ })
+ .log_err();
+ }
+ })
+ }))
+ }
+ }),
),
)
.child(
diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs
new file mode 100644
index 0000000000..94b32d156b
--- /dev/null
+++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs
@@ -0,0 +1,639 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use collections::HashSet;
+use fs::Fs;
+use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
+use language_model::LanguageModelRegistry;
+use language_models::{
+ AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
+ provider::open_ai_compatible::AvailableModel,
+};
+use settings::update_settings_file;
+use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
+use ui_input::SingleLineInput;
+use workspace::{ModalView, Workspace};
+
+/// API flavours that a user-added LLM provider can be declared compatible with.
+///
+/// The variant selects the placeholder name and URL shown in the "Add LLM Provider" form as
+/// well as the description rendered in the modal header.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum LlmCompatibleProvider {
+    /// A provider that exposes an OpenAI compatible API.
+    OpenAi,
+}
+
+impl LlmCompatibleProvider {
+    /// Display name of the reference provider; used as the placeholder of the
+    /// "Provider Name" input.
+    fn name(&self) -> &'static str {
+        match *self {
+            Self::OpenAi => "OpenAI",
+        }
+    }
+
+    /// Base URL of the reference provider; used as the placeholder of the "API URL" input.
+    fn api_url(&self) -> &'static str {
+        match *self {
+            Self::OpenAi => "https://api.openai.com/v1",
+        }
+    }
+}
+
+/// Editable form state backing the "Add LLM Provider" modal.
+struct AddLlmProviderInput {
+    /// Name under which the provider is stored in the settings.
+    provider_name: Entity<SingleLineInput>,
+    /// Base URL of the provider's API; also used as the key for the stored credentials.
+    api_url: Entity<SingleLineInput>,
+    /// API key that is written to the credentials store on save.
+    api_key: Entity<SingleLineInput>,
+    /// One entry per model the user wants to expose; starts with a single entry.
+    models: Vec<ModelInput>,
+}
+
+impl AddLlmProviderInput {
+    /// Creates the form state for `provider` with one blank model entry.
+    ///
+    /// The provider's name and URL, plus a dummy key, are used as placeholders only; none of
+    /// the three provider inputs receives initial text.
+    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
+        let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
+        let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
+        let api_key = single_line_input(
+            "API Key",
+            "000000000000000000000000000000000000000000000000",
+            None,
+            window,
+            cx,
+        );
+
+        Self {
+            provider_name,
+            api_url,
+            api_key,
+            models: vec![ModelInput::new(window, cx)],
+        }
+    }
+
+    /// Appends a new model entry created by [`ModelInput::new`].
+    fn add_model(&mut self, window: &mut Window, cx: &mut App) {
+        self.models.push(ModelInput::new(window, cx));
+    }
+
+    /// Removes the model entry at `index`; out-of-range indices are ignored.
+    ///
+    /// The index is captured by a click handler when the entry is rendered, so it may no
+    /// longer be valid when the handler runs. Ignoring it avoids the panic that
+    /// `Vec::remove` raises for an out-of-bounds index.
+    fn remove_model(&mut self, index: usize) {
+        if index < self.models.len() {
+            self.models.remove(index);
+        }
+    }
+}
+
+/// Input fields describing a single model of the provider being added.
+struct ModelInput {
+    /// Model identifier; must be non-empty and unique within the provider.
+    name: Entity<SingleLineInput>,
+    /// Parsed into `AvailableModel::max_completion_tokens`.
+    max_completion_tokens: Entity<SingleLineInput>,
+    /// Parsed into `AvailableModel::max_output_tokens`.
+    max_output_tokens: Entity<SingleLineInput>,
+    /// Parsed into `AvailableModel::max_tokens`.
+    max_tokens: Entity<SingleLineInput>,
+}
+
+impl ModelInput {
+    /// Creates the inputs for one model.
+    ///
+    /// The name starts empty; the limits are prefilled with 200000 (max completion tokens),
+    /// 32000 (max output tokens) and 200000 (max tokens).
+    fn new(window: &mut Window, cx: &mut App) -> Self {
+        let model_name = single_line_input(
+            "Model Name",
+            "e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
+            None,
+            window,
+            cx,
+        );
+        let max_completion_tokens = single_line_input(
+            "Max Completion Tokens",
+            "200000",
+            Some("200000"),
+            window,
+            cx,
+        );
+        let max_output_tokens = single_line_input(
+            "Max Output Tokens",
+            "Max Output Tokens",
+            Some("32000"),
+            window,
+            cx,
+        );
+        let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
+        Self {
+            name: model_name,
+            max_completion_tokens,
+            max_output_tokens,
+            max_tokens,
+        }
+    }
+
+    /// Converts the current field contents into an [`AvailableModel`].
+    ///
+    /// # Errors
+    ///
+    /// Returns a user-facing message when the name is empty or when a limit does not parse
+    /// as a number. Fields are checked in the order name, max completion tokens, max output
+    /// tokens, max tokens, and the first failure is reported.
+    fn parse(&self, cx: &App) -> Result<AvailableModel, SharedString> {
+        let name = self.name.read(cx).text(cx);
+        if name.is_empty() {
+            return Err(SharedString::from("Model Name cannot be empty"));
+        }
+        // The numeric type of each `parse()` is inferred from the `AvailableModel` field
+        // it is assigned to.
+        Ok(AvailableModel {
+            name,
+            display_name: None,
+            max_completion_tokens: Some(
+                self.max_completion_tokens
+                    .read(cx)
+                    .text(cx)
+                    .parse()
+                    .map_err(|_| SharedString::from("Max Completion Tokens must be a number"))?,
+            ),
+            max_output_tokens: Some(
+                self.max_output_tokens
+                    .read(cx)
+                    .text(cx)
+                    .parse()
+                    .map_err(|_| SharedString::from("Max Output Tokens must be a number"))?,
+            ),
+            max_tokens: self
+                .max_tokens
+                .read(cx)
+                .text(cx)
+                .parse()
+                .map_err(|_| SharedString::from("Max Tokens must be a number"))?,
+        })
+    }
+}
+
+/// Creates a labelled [`SingleLineInput`] entity.
+///
+/// `placeholder` is passed to the input's constructor and `label` to its `label` builder.
+/// When `text` is `Some`, the input's editor is prefilled with it; otherwise the input
+/// starts empty.
+fn single_line_input(
+    label: impl Into<SharedString>,
+    placeholder: impl Into<SharedString>,
+    text: Option<&str>,
+    window: &mut Window,
+    cx: &mut App,
+) -> Entity<SingleLineInput> {
+    cx.new(|cx| {
+        let input = SingleLineInput::new(window, cx, placeholder).label(label);
+        if let Some(text) = text {
+            input
+                .editor()
+                .update(cx, |editor, cx| editor.set_text(text, window, cx));
+        }
+        input
+    })
+}
+
+/// Validates `input` and, if it is valid, persists the new provider.
+///
+/// Validation happens synchronously, in this order: provider name non-empty, provider name
+/// not equal to the id or name of any registered provider, API URL non-empty, API key
+/// non-empty, every model parses, model names unique. The first failure is returned as an
+/// already-completed task.
+///
+/// On success the returned task writes the API key to the credentials store (keyed by the
+/// API URL, with the user name "Bearer") and then inserts the provider into the
+/// `openai_compatible` section of the settings file.
+///
+/// # Errors
+///
+/// The task resolves to a user-facing message for any validation failure or when the
+/// credentials cannot be written. A failure of the subsequent `cx.update` call is ignored.
+fn save_provider_to_settings(
+    input: &AddLlmProviderInput,
+    cx: &mut App,
+) -> Task<Result<(), SharedString>> {
+    let provider_name: Arc<str> = input.provider_name.read(cx).text(cx).into();
+    if provider_name.is_empty() {
+        return Task::ready(Err("Provider Name cannot be empty".into()));
+    }
+
+    // The name becomes the provider's key in the settings, so it must not collide with
+    // either the id or the display name of a provider that is already registered.
+    if LanguageModelRegistry::read_global(cx)
+        .providers()
+        .iter()
+        .any(|provider| {
+            provider.id().0.as_ref() == provider_name.as_ref()
+                || provider.name().0.as_ref() == provider_name.as_ref()
+        })
+    {
+        return Task::ready(Err(
+            "Provider Name is already taken by another provider".into()
+        ));
+    }
+
+    let api_url = input.api_url.read(cx).text(cx);
+    if api_url.is_empty() {
+        return Task::ready(Err("API URL cannot be empty".into()));
+    }
+
+    let api_key = input.api_key.read(cx).text(cx);
+    if api_key.is_empty() {
+        return Task::ready(Err("API Key cannot be empty".into()));
+    }
+
+    let mut models = Vec::new();
+    let mut model_names: HashSet<String> = HashSet::default();
+    for model in &input.models {
+        match model.parse(cx) {
+            Ok(model) => {
+                // `insert` returns false when the name was already present.
+                if !model_names.insert(model.name.clone()) {
+                    return Task::ready(Err("Model Names must be unique".into()));
+                }
+                models.push(model)
+            }
+            Err(err) => return Task::ready(Err(err)),
+        }
+    }
+
+    let fs = <dyn Fs>::global(cx);
+    let task = cx.write_credentials(&api_url, "Bearer", api_key.as_bytes());
+    cx.spawn(async move |cx| {
+        task.await
+            .map_err(|_| "Failed to write API key to keychain")?;
+        cx.update(|cx| {
+            update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
+                settings.openai_compatible.get_or_insert_default().insert(
+                    provider_name,
+                    OpenAiCompatibleSettingsContent {
+                        api_url,
+                        available_models: models,
+                    },
+                );
+            });
+        })
+        .ok();
+        Ok(())
+    })
+}
+
+/// Modal that collects the details of a new API-compatible LLM provider and saves them.
+pub struct AddLlmProviderModal {
+    /// API flavour the new provider is compatible with; selects the header description.
+    provider: LlmCompatibleProvider,
+    /// Current contents of the form.
+    input: AddLlmProviderInput,
+    focus_handle: FocusHandle,
+    /// Message of the most recent failed save, rendered as a warning banner.
+    last_error: Option<SharedString>,
+}
+
+impl AddLlmProviderModal {
+    /// Toggles this modal on `workspace`, building it for the given `provider` flavour.
+    pub fn toggle(
+        provider: LlmCompatibleProvider,
+        workspace: &mut Workspace,
+        window: &mut Window,
+        cx: &mut Context<Workspace>,
+    ) {
+        workspace.toggle_modal(window, cx, |window, cx| Self::new(provider, window, cx));
+    }
+
+    /// Creates the modal with a fresh form, no error banner and a new focus handle.
+    fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut Context<Self>) -> Self {
+        Self {
+            input: AddLlmProviderInput::new(provider, window, cx),
+            provider,
+            last_error: None,
+            focus_handle: cx.focus_handle(),
+        }
+    }
+
+    /// Validates and saves the form.
+    ///
+    /// Emits [`DismissEvent`] when saving succeeds; otherwise stores the error message in
+    /// `last_error` and requests a re-render so the banner shows it.
+    fn confirm(&mut self, _: &menu::Confirm, _: &mut Window, cx: &mut Context<Self>) {
+        let task = save_provider_to_settings(&self.input, cx);
+        cx.spawn(async move |this, cx| {
+            let result = task.await;
+            this.update(cx, |this, cx| match result {
+                Ok(_) => {
+                    cx.emit(DismissEvent);
+                }
+                Err(error) => {
+                    this.last_error = Some(error);
+                    cx.notify();
+                }
+            })
+        })
+        .detach_and_log_err(cx);
+    }
+
+    /// Dismisses the modal without saving.
+    fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
+        cx.emit(DismissEvent);
+    }
+
+    /// Section holding the provider-level inputs (name, API URL, API key).
+    fn render_section(&self) -> Section {
+        Section::new()
+            .child(self.input.provider_name.clone())
+            .child(self.input.api_url.clone())
+            .child(self.input.api_key.clone())
+    }
+
+    /// Section holding the "Models" header, the "Add Model" button and one card per model.
+    fn render_model_section(&self, cx: &mut Context<Self>) -> Section {
+        Section::new().child(
+            v_flex()
+                .gap_2()
+                .child(
+                    h_flex()
+                        .justify_between()
+                        .child(Label::new("Models").size(LabelSize::Small))
+                        .child(
+                            Button::new("add-model", "Add Model")
+                                .icon(IconName::Plus)
+                                .icon_position(IconPosition::Start)
+                                .icon_size(IconSize::XSmall)
+                                .icon_color(Color::Muted)
+                                .label_size(LabelSize::Small)
+                                .on_click(cx.listener(|this, _, window, cx| {
+                                    this.input.add_model(window, cx);
+                                    cx.notify();
+                                })),
+                        ),
+                )
+                .children(
+                    self.input
+                        .models
+                        .iter()
+                        .enumerate()
+                        .map(|(ix, _)| self.render_model(ix, cx)),
+                ),
+        )
+    }
+
+    /// Renders the card for the model at index `ix`.
+    ///
+    /// The "Remove Model" button is only shown while more than one model exists, so the
+    /// form always keeps at least one entry.
+    fn render_model(&self, ix: usize, cx: &mut Context<Self>) -> impl IntoElement + use<> {
+        let has_more_than_one_model = self.input.models.len() > 1;
+        let model = &self.input.models[ix];
+
+        v_flex()
+            .p_2()
+            .gap_2()
+            .rounded_sm()
+            .border_1()
+            .border_dashed()
+            .border_color(cx.theme().colors().border.opacity(0.6))
+            .bg(cx.theme().colors().element_active.opacity(0.15))
+            .child(model.name.clone())
+            .child(
+                h_flex()
+                    .gap_2()
+                    .child(model.max_completion_tokens.clone())
+                    .child(model.max_output_tokens.clone()),
+            )
+            .child(model.max_tokens.clone())
+            .when(has_more_than_one_model, |this| {
+                this.child(
+                    Button::new(("remove-model", ix), "Remove Model")
+                        .icon(IconName::Trash)
+                        .icon_position(IconPosition::Start)
+                        .icon_size(IconSize::XSmall)
+                        .icon_color(Color::Muted)
+                        .label_size(LabelSize::Small)
+                        .style(ButtonStyle::Outlined)
+                        .full_width()
+                        .on_click(cx.listener(move |this, _, _window, cx| {
+                            this.input.remove_model(ix);
+                            cx.notify();
+                        })),
+                )
+            })
+    }
+}
+
+/// The modal emits [`DismissEvent`] when it is cancelled or after a successful save.
+impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
+
+impl Focusable for AddLlmProviderModal {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle { // returns a clone of the handle created in `new`
+ self.focus_handle.clone()
+ }
+}
+
+impl ModalView for AddLlmProviderModal {} // marker impl with no overrides; lets `Workspace::toggle_modal` host this view
+
+impl Render for AddLlmProviderModal {
+    /// Renders the modal: header, optional warning banner with the last save error, the
+    /// scrollable form (provider fields followed by the model cards) and the footer with
+    /// the "Cancel" and "Save Provider" buttons.
+    fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
+        let focus_handle = self.focus_handle(cx);
+
+        div()
+            .id("add-llm-provider-modal")
+            .key_context("AddLlmProviderModal")
+            .w(rems(34.))
+            .elevation_3(cx)
+            .on_action(cx.listener(Self::cancel))
+            // The footer advertises the `menu::Confirm` key binding on the save button, so
+            // the action has to be handled here as well, not only through the click.
+            .on_action(cx.listener(Self::confirm))
+            .capture_any_mouse_down(cx.listener(|this, _, window, cx| {
+                this.focus_handle(cx).focus(window);
+            }))
+            .child(
+                Modal::new("configure-context-server", None)
+                    .header(ModalHeader::new().headline("Add LLM Provider").description(
+                        match self.provider {
+                            LlmCompatibleProvider::OpenAi => {
+                                "This provider will use an OpenAI compatible API."
+                            }
+                        },
+                    ))
+                    .when_some(self.last_error.clone(), |this, error| {
+                        this.section(
+                            Section::new().child(
+                                Banner::new()
+                                    .severity(ui::Severity::Warning)
+                                    .child(div().text_xs().child(error)),
+                            ),
+                        )
+                    })
+                    .child(
+                        v_flex()
+                            .id("modal_content")
+                            .max_h_128()
+                            .overflow_y_scroll()
+                            .gap_2()
+                            .child(self.render_section())
+                            .child(self.render_model_section(cx)),
+                    )
+                    .footer(
+                        ModalFooter::new().end_slot(
+                            h_flex()
+                                .gap_1()
+                                .child(
+                                    Button::new("cancel", "Cancel")
+                                        .key_binding(
+                                            KeyBinding::for_action_in(
+                                                &menu::Cancel,
+                                                &focus_handle,
+                                                window,
+                                                cx,
+                                            )
+                                            .map(|kb| kb.size(rems_from_px(12.))),
+                                        )
+                                        .on_click(cx.listener(|this, _event, window, cx| {
+                                            this.cancel(&menu::Cancel, window, cx)
+                                        })),
+                                )
+                                .child(
+                                    Button::new("save-server", "Save Provider")
+                                        .key_binding(
+                                            KeyBinding::for_action_in(
+                                                &menu::Confirm,
+                                                &focus_handle,
+                                                window,
+                                                cx,
+                                            )
+                                            .map(|kb| kb.size(rems_from_px(12.))),
+                                        )
+                                        .on_click(cx.listener(|this, _event, window, cx| {
+                                            this.confirm(&menu::Confirm, window, cx)
+                                        })),
+                                ),
+                        ),
+                    ),
+            )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use editor::EditorSettings;
+ use fs::FakeFs;
+ use gpui::{TestAppContext, VisualTestContext};
+ use language::language_settings;
+ use language_model::{
+ LanguageModelProviderId, LanguageModelProviderName,
+ fake_provider::FakeLanguageModelProvider,
+ };
+ use project::Project;
+ use settings::{Settings as _, SettingsStore};
+ use util::path;
+
+ #[gpui::test]
+ async fn test_save_provider_invalid_inputs(cx: &mut TestAppContext) { // each case feeds one invalid field and expects the matching validation message
+ let cx = setup_test(cx).await;
+
+ assert_eq!(
+ save_provider_validation_errors("", "someurl", "somekey", vec![], cx,).await,
+ Some("Provider Name cannot be empty".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors("someprovider", "", "somekey", vec![], cx,).await,
+ Some("API URL cannot be empty".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors("someprovider", "someurl", "", vec![], cx,).await,
+ Some("API Key cannot be empty".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "somekey",
+ vec![("", "200000", "200000", "32000")], // tuple order: (name, max_tokens, max_completion_tokens, max_output_tokens)
+ cx,
+ )
+ .await,
+ Some("Model Name cannot be empty".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "somekey",
+ vec![("somemodel", "abc", "200000", "32000")], // second element is max_tokens
+ cx,
+ )
+ .await,
+ Some("Max Tokens must be a number".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "somekey",
+ vec![("somemodel", "200000", "abc", "32000")], // third element is max_completion_tokens
+ cx,
+ )
+ .await,
+ Some("Max Completion Tokens must be a number".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "somekey",
+ vec![("somemodel", "200000", "200000", "abc")], // fourth element is max_output_tokens
+ cx,
+ )
+ .await,
+ Some("Max Output Tokens must be a number".into())
+ );
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "somekey",
+ vec![
+ ("somemodel", "200000", "200000", "32000"),
+ ("somemodel", "200000", "200000", "32000"), // same name twice triggers the uniqueness check
+ ],
+ cx,
+ )
+ .await,
+ Some("Model Names must be unique".into())
+ );
+ }
+
+ #[gpui::test]
+ async fn test_save_provider_name_conflict(cx: &mut TestAppContext) { // saving must fail when the entered name matches a registered provider
+ let cx = setup_test(cx).await;
+
+ cx.update(|_window, cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.register_provider(
+ FakeLanguageModelProvider::new(
+ LanguageModelProviderId::new("someprovider"), // the id is what collides with the name entered below
+ LanguageModelProviderName::new("Some Provider"), // display name differs from the entered name
+ ),
+ cx,
+ );
+ });
+ });
+
+ assert_eq!(
+ save_provider_validation_errors(
+ "someprovider",
+ "someurl",
+ "someapikey",
+ vec![("somemodel", "200000", "200000", "32000")],
+ cx,
+ )
+ .await,
+ Some("Provider Name is already taken by another provider".into())
+ );
+ }
+
+    /// Prepares a test app: installs a test `SettingsStore`, initialises the settings and
+    /// theme the modal code reads, registers a `FakeFs` as the global file system and opens
+    /// a window containing a test workspace. Returns the window's visual test context.
+    async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
+        cx.update(|cx| {
+            let store = SettingsStore::test(cx);
+            cx.set_global(store);
+            workspace::init_settings(cx);
+            Project::init_settings(cx);
+            theme::init(theme::LoadThemes::JustBase, cx);
+            language_settings::init(cx);
+            EditorSettings::register(cx);
+            language_model::init_settings(cx);
+            language_models::init_settings(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        // `save_provider_to_settings` looks the file system up as a global.
+        cx.update(|cx| <dyn Fs>::set_global(fs.clone(), cx));
+        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+        let (_, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        cx
+    }
+
+    /// Fills a fresh [`AddLlmProviderInput`] with the given values, runs
+    /// `save_provider_to_settings` and returns its error, or `None` when saving succeeded.
+    ///
+    /// Each entry of `models` is `(name, max_tokens, max_completion_tokens,
+    /// max_output_tokens)`. The form starts with one model entry; further entries are
+    /// appended as needed. With an empty `models`, that initial entry is left untouched.
+    async fn save_provider_validation_errors(
+        provider_name: &str,
+        api_url: &str,
+        api_key: &str,
+        models: Vec<(&str, &str, &str, &str)>,
+        cx: &mut VisualTestContext,
+    ) -> Option<SharedString> {
+        /// Replaces the text of `input`'s editor with `text`.
+        fn set_text(
+            input: &Entity<SingleLineInput>,
+            text: &str,
+            window: &mut Window,
+            cx: &mut App,
+        ) {
+            input.update(cx, |input, cx| {
+                input.editor().update(cx, |editor, cx| {
+                    editor.set_text(text, window, cx);
+                });
+            });
+        }
+
+        let task = cx.update(|window, cx| {
+            let mut input = AddLlmProviderInput::new(LlmCompatibleProvider::OpenAi, window, cx);
+            set_text(&input.provider_name, provider_name, window, cx);
+            set_text(&input.api_url, api_url, window, cx);
+            set_text(&input.api_key, api_key, window, cx);
+
+            for (i, (name, max_tokens, max_completion_tokens, max_output_tokens)) in
+                models.iter().enumerate()
+            {
+                if i >= input.models.len() {
+                    input.models.push(ModelInput::new(window, cx));
+                }
+                let model = &mut input.models[i];
+                set_text(&model.name, name, window, cx);
+                set_text(&model.max_tokens, max_tokens, window, cx);
+                set_text(
+                    &model.max_completion_tokens,
+                    max_completion_tokens,
+                    window,
+                    cx,
+                );
+                set_text(&model.max_output_tokens, max_output_tokens, window, cx);
+            }
+            save_provider_to_settings(&input, cx)
+        });
+
+        task.await.err()
+    }
+}
diff --git a/crates/assistant_context/src/assistant_context_tests.rs b/crates/assistant_context/src/assistant_context_tests.rs
index dba3bfde61..f139d525d3 100644
--- a/crates/assistant_context/src/assistant_context_tests.rs
+++ b/crates/assistant_context/src/assistant_context_tests.rs
@@ -1323,7 +1323,7 @@ fn setup_context_editor_with_fake_model(
) -> (Entity, Arc) {
let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
- let fake_provider = Arc::new(FakeLanguageModelProvider);
+ let fake_provider = Arc::new(FakeLanguageModelProvider::default());
let fake_model = Arc::new(fake_provider.test_model());
cx.update(|cx| {
diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs
index ec315d9ab1..7567926dca 100644
--- a/crates/assistant_tools/src/project_notifications_tool.rs
+++ b/crates/assistant_tools/src/project_notifications_tool.rs
@@ -200,7 +200,7 @@ mod tests {
// Run the tool before any changes
let tool = Arc::new(ProjectNotificationsTool);
- let provider = Arc::new(FakeLanguageModelProvider);
+ let provider = Arc::new(FakeLanguageModelProvider::default());
let model: Arc = Arc::new(provider.test_model());
let request = Arc::new(LanguageModelRequest::default());
let tool_input = json!({});
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs
index b85e5b517d..e7066ae151 100644
--- a/crates/icons/src/icons.rs
+++ b/crates/icons/src/icons.rs
@@ -20,6 +20,7 @@ pub enum IconName {
AiMistral,
AiOllama,
AiOpenAi,
+ AiOpenAiCompat,
AiOpenRouter,
AiVZero,
AiXAi,
diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs
index f5191016d8..d54db7554a 100644
--- a/crates/language_model/src/fake_provider.rs
+++ b/crates/language_model/src/fake_provider.rs
@@ -10,25 +10,21 @@ use http_client::Result;
use parking_lot::Mutex;
use std::sync::Arc;
-pub fn language_model_id() -> LanguageModelId {
- LanguageModelId::from("fake".to_string())
+#[derive(Clone)]
+pub struct FakeLanguageModelProvider { // test provider whose id and name are configurable
+ id: LanguageModelProviderId, // returned by `LanguageModelProvider::id`
+ name: LanguageModelProviderName, // returned by `LanguageModelProvider::name`
}
-pub fn language_model_name() -> LanguageModelName {
- LanguageModelName::from("Fake".to_string())
+impl Default for FakeLanguageModelProvider {
+ fn default() -> Self { // same "fake"/"Fake" values the removed `provider_id()`/`provider_name()` helpers returned
+ Self {
+ id: LanguageModelProviderId::from("fake".to_string()),
+ name: LanguageModelProviderName::from("Fake".to_string()),
+ }
+ }
}
-pub fn provider_id() -> LanguageModelProviderId {
- LanguageModelProviderId::from("fake".to_string())
-}
-
-pub fn provider_name() -> LanguageModelProviderName {
- LanguageModelProviderName::from("Fake".to_string())
-}
-
-#[derive(Clone, Default)]
-pub struct FakeLanguageModelProvider;
-
impl LanguageModelProviderState for FakeLanguageModelProvider {
type ObservableEntity = ();
@@ -39,11 +35,11 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
impl LanguageModelProvider for FakeLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- provider_id()
+ self.id.clone()
}
fn name(&self) -> LanguageModelProviderName {
- provider_name()
+ self.name.clone()
}
fn default_model(&self, _cx: &App) -> Option> {
@@ -76,6 +72,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
}
impl FakeLanguageModelProvider {
+ pub fn new(id: LanguageModelProviderId, name: LanguageModelProviderName) -> Self { // builds a fake provider with a caller-chosen id and name
+ Self { id, name }
+ }
+
pub fn test_model(&self) -> FakeLanguageModel {
FakeLanguageModel::default()
}
@@ -89,11 +89,22 @@ pub struct ToolUseRequest {
pub schema: serde_json::Value,
}
-#[derive(Default)]
pub struct FakeLanguageModel {
+ provider_id: LanguageModelProviderId, // returned by `LanguageModel::provider_id`
+ provider_name: LanguageModelProviderName, // returned by `LanguageModel::provider_name`
current_completion_txs: Mutex)>>,
}
+impl Default for FakeLanguageModel {
+ fn default() -> Self { // written by hand now that `#[derive(Default)]` was removed from the struct
+ Self {
+ provider_id: LanguageModelProviderId::from("fake".to_string()), // same values as `FakeLanguageModelProvider::default()`
+ provider_name: LanguageModelProviderName::from("Fake".to_string()),
+ current_completion_txs: Mutex::new(Vec::new()), // starts with no pending completions
+ }
+ }
+}
+
impl FakeLanguageModel {
pub fn pending_completions(&self) -> Vec {
self.current_completion_txs
@@ -138,19 +149,19 @@ impl FakeLanguageModel {
impl LanguageModel for FakeLanguageModel {
fn id(&self) -> LanguageModelId {
- language_model_id()
+ LanguageModelId::from("fake".to_string())
}
fn name(&self) -> LanguageModelName {
- language_model_name()
+ LanguageModelName::from("Fake".to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
- provider_id()
+ self.provider_id.clone()
}
fn provider_name(&self) -> LanguageModelProviderName {
- provider_name()
+ self.provider_name.clone()
}
fn supports_tools(&self) -> bool {
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 72455b3821..54640419b6 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -735,6 +735,18 @@ impl From for LanguageModelProviderName {
}
}
+/// Builds a provider id from a shared string slice by wrapping it in a `SharedString`.
+impl From<Arc<str>> for LanguageModelProviderId {
+    fn from(value: Arc<str>) -> Self {
+        Self(SharedString::from(value))
+    }
+}
+
+/// Builds a provider name from a shared string slice by wrapping it in a `SharedString`.
+impl From<Arc<str>> for LanguageModelProviderName {
+    fn from(value: Arc<str>) -> Self {
+        Self(SharedString::from(value))
+    }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index 6e8e8e9108..7cf071808a 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -125,7 +125,7 @@ impl LanguageModelRegistry {
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
- let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
+ let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
let registry = cx.new(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
@@ -403,16 +403,17 @@ mod tests {
fn test_register_providers(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
+ let provider = FakeLanguageModelProvider::default();
registry.update(cx, |registry, cx| {
- registry.register_provider(FakeLanguageModelProvider, cx);
+ registry.register_provider(provider.clone(), cx);
});
let providers = registry.read(cx).providers();
assert_eq!(providers.len(), 1);
- assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
+ assert_eq!(providers[0].id(), provider.id());
registry.update(cx, |registry, cx| {
- registry.unregister_provider(crate::fake_provider::provider_id(), cx);
+ registry.unregister_provider(provider.id(), cx);
});
let providers = registry.read(cx).providers();
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index ed38ac7660..574579aaa7 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -26,10 +26,10 @@ client.workspace = true
collections.workspace = true
component.workspace = true
credentials_provider.workspace = true
+convert_case.workspace = true
copilot.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
-fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 192f5a5fae..18e6f47ed0 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -1,8 +1,10 @@
use std::sync::Arc;
+use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
+use collections::HashSet;
use gpui::{App, Context, Entity};
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
pub mod provider;
@@ -18,17 +20,81 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
+use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity, client: Arc, cx: &mut App) {
- crate::settings::init(cx);
+ crate::settings::init_settings(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_language_model_providers(registry, user_store, client, cx);
+ register_language_model_providers(registry, user_store, client.clone(), cx); // clone: `client` is used again below
});
+
+ let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) // provider ids currently configured; baseline for later diffs
+ .openai_compatible
+ .keys()
+ .cloned()
+ .collect::>(); // NOTE(review): the collection type in this turbofish looks truncated — confirm against the real file
+
+ registry.update(cx, |registry, cx| {
+ register_openai_compatible_providers(
+ registry,
+ &HashSet::default(), // empty "old" set, so every configured provider gets registered
+ &openai_compatible_providers,
+ client.clone(),
+ cx,
+ );
+ });
+ cx.observe_global::(move |cx| { // NOTE(review): the observed global type is missing here — presumably SettingsStore; confirm
+ let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
+ .openai_compatible
+ .keys()
+ .cloned()
+ .collect::>();
+ if openai_compatible_providers_new != openai_compatible_providers { // only the set of provider ids is compared
+ registry.update(cx, |registry, cx| {
+ register_openai_compatible_providers(
+ registry,
+ &openai_compatible_providers,
+ &openai_compatible_providers_new,
+ client.clone(),
+ cx,
+ );
+ });
+ openai_compatible_providers = openai_compatible_providers_new; // remember the new set for the next change
+ }
+ })
+ .detach();
+}
+
+/// Brings the registry in line with a changed set of OpenAI-compatible provider ids.
+///
+/// Providers whose id is in `old` but not in `new` are unregistered; providers whose id is
+/// in `new` but not in `old` are created with `client`'s HTTP client and registered. Ids
+/// present in both sets are left untouched.
+fn register_openai_compatible_providers(
+    registry: &mut LanguageModelRegistry,
+    old: &HashSet<Arc<str>>,
+    new: &HashSet<Arc<str>>,
+    client: Arc<Client>,
+    cx: &mut Context<LanguageModelRegistry>,
+) {
+    for provider_id in old {
+        if !new.contains(provider_id) {
+            registry.unregister_provider(LanguageModelProviderId::from(provider_id.clone()), cx);
+        }
+    }
+
+    for provider_id in new {
+        if !old.contains(provider_id) {
+            registry.register_provider(
+                OpenAiCompatibleLanguageModelProvider::new(
+                    provider_id.clone(),
+                    client.http_client(),
+                    cx,
+                ),
+                cx,
+            );
+        }
+    }
}
fn register_language_model_providers(
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index c717be7c90..d780195c66 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -8,6 +8,7 @@ pub mod lmstudio;
pub mod mistral;
pub mod ollama;
pub mod open_ai;
+pub mod open_ai_compatible;
pub mod open_router;
pub mod vercel;
pub mod x_ai;
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 76f2fbe303..6c4d4c9b3e 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -2,7 +2,6 @@ use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
-use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
@@ -18,7 +17,7 @@ use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use settings::{Settings, SettingsStore, update_settings_file};
+use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
@@ -28,7 +27,6 @@ use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
-use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
@@ -621,26 +619,32 @@ struct RawToolCall {
arguments: String,
}
+/// Converts every message of `request` into a `tiktoken_rs::ChatCompletionRequestMessage`.
+///
+/// The role is mapped to the string "user", "assistant" or "system", the content is the
+/// message's `string_contents()`, and `name` / `function_call` are always `None`.
+pub(crate) fn collect_tiktoken_messages(
+ request: LanguageModelRequest,
+) -> Vec {
+ request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::>()
+}
+
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: Model,
cx: &App,
) -> BoxFuture<'static, Result> {
cx.background_spawn(async move {
- let messages = request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::>();
+ let messages = collect_tiktoken_messages(request);
match model {
Model::Custom { max_tokens, .. } => {
@@ -678,7 +682,6 @@ pub fn count_open_ai_tokens(
struct ConfigurationView {
api_key_editor: Entity,
- api_url_editor: Entity,
state: gpui::Entity,
load_credentials_task: Option>,
}
@@ -691,23 +694,6 @@ impl ConfigurationView {
cx,
"sk-000000000000000000000000000000000000000000000000",
)
- .label("API key")
- });
-
- let api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
-
- let api_url_editor = cx.new(|cx| {
- let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");
-
- if !api_url.is_empty() {
- input.editor.update(cx, |editor, cx| {
- editor.set_text(&*api_url, window, cx);
- });
- }
- input
});
cx.observe(&state, |_, _, cx| {
@@ -735,7 +721,6 @@ impl ConfigurationView {
Self {
api_key_editor,
- api_url_editor,
state,
load_credentials_task,
}
@@ -783,57 +768,6 @@ impl ConfigurationView {
cx.notify();
}
- fn save_api_url(&mut self, cx: &mut Context) {
- let api_url = self
- .api_url_editor
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .trim()
- .to_string();
-
- let current_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
-
- let effective_current_url = if current_url.is_empty() {
- open_ai::OPEN_AI_API_URL
- } else {
- ¤t_url
- };
-
- if !api_url.is_empty() && api_url != effective_current_url {
- let fs = ::global(cx);
- update_settings_file::(fs, cx, move |settings, _| {
- if let Some(settings) = settings.openai.as_mut() {
- settings.api_url = Some(api_url.clone());
- } else {
- settings.openai = Some(OpenAiSettingsContent {
- api_url: Some(api_url.clone()),
- available_models: None,
- });
- }
- });
- }
- }
-
- fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context) {
- self.api_url_editor.update(cx, |input, cx| {
- input.editor.update(cx, |editor, cx| {
- editor.set_text("", window, cx);
- });
- });
- let fs = ::global(cx);
- update_settings_file::(fs, cx, |settings, _cx| {
- if let Some(settings) = settings.openai.as_mut() {
- settings.api_url = None;
- }
- });
- cx.notify();
- }
-
fn should_render_editor(&self, cx: &mut Context) -> bool {
!self.state.read(cx).is_authenticated()
}
@@ -846,7 +780,6 @@ impl Render for ConfigurationView {
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.on_action(cx.listener(Self::save_api_key))
-
.child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
@@ -910,59 +843,34 @@ impl Render for ConfigurationView {
.into_any()
};
- let custom_api_url_set =
- AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;
-
- let api_url_section = if custom_api_url_set {
- h_flex()
- .mt_1()
- .p_1()
- .justify_between()
- .rounded_md()
- .border_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().background)
- .child(
- h_flex()
- .gap_1()
- .child(Icon::new(IconName::Check).color(Color::Success))
- .child(Label::new("Custom API URL configured.")),
- )
- .child(
- Button::new("reset-api-url", "Reset API URL")
- .label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
- .layer(ElevationIndex::ModalSurface)
- .on_click(
- cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
- ),
- )
- .into_any()
- } else {
- v_flex()
- .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
- this.save_api_url(cx);
- cx.notify();
- }))
- .mt_2()
- .pt_2()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .gap_1()
- .child(
- List::new()
- .child(InstructionListItem::text_only(
- "Optionally, you can change the base URL for the OpenAI API request.",
- ))
- .child(InstructionListItem::text_only(
- "Paste the new API endpoint below and hit enter",
- )),
- )
- .child(self.api_url_editor.clone())
- .into_any()
- };
+ let compatible_api_section = h_flex()
+ .mt_1p5()
+ .gap_0p5()
+ .flex_wrap()
+ .when(self.should_render_editor(cx), |this| {
+ this.pt_1p5()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ })
+ .child(
+ h_flex()
+ .gap_2()
+ .child(
+ Icon::new(IconName::Info)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(Label::new("Zed also supports OpenAI-compatible models.")),
+ )
+ .child(
+ Button::new("docs", "Learn More")
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Muted)
+ .on_click(move |_, _window, cx| {
+ cx.open_url("https://zed.dev/docs/ai/configuration#openai-api-compatible")
+ }),
+ );
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
@@ -970,7 +878,7 @@ impl Render for ConfigurationView {
v_flex()
.size_full()
.child(api_key_section)
- .child(api_url_section)
+ .child(compatible_api_section)
.into_any()
}
}
diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs
new file mode 100644
index 0000000000..64add5483d
--- /dev/null
+++ b/crates/language_models/src/provider/open_ai_compatible.rs
@@ -0,0 +1,522 @@
+use anyhow::{Context as _, Result, anyhow};
+use credentials_provider::CredentialsProvider;
+
+use convert_case::{Case, Casing};
+use futures::{FutureExt, StreamExt, future::BoxFuture};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
+use http_client::HttpClient;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelToolChoice, RateLimiter,
+};
+use menu;
+use open_ai::{ResponseStreamEvent, stream_completion};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
+
+use ui::{ElevationIndex, Tooltip, prelude::*};
+use ui_input::SingleLineInput;
+use util::ResultExt;
+
+use crate::AllLanguageModelSettings;
+use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
+
+/// Resolved settings for one OpenAI-compatible provider.
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenAiCompatibleSettings {
+ /// URL passed to `stream_completion`; also the key under which credentials are stored.
+ pub api_url: String,
+ /// Models offered by this provider; the first entry is used as the default model.
+ pub available_models: Vec,
+}
+
+/// One model entry of an OpenAI-compatible provider's `available_models` setting.
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+ /// Model name sent in requests; also used as the language model id.
+ pub name: String,
+ /// Name reported by `LanguageModel::name`; falls back to `name` when absent.
+ pub display_name: Option,
+ /// Value reported by `max_token_count`.
+ pub max_tokens: u64,
+ /// Value reported by `max_output_tokens` and forwarded when building the request.
+ pub max_output_tokens: Option,
+ // NOTE(review): not read by any code visible in this file — confirm whether it is used.
+ pub max_completion_tokens: Option,
+}
+
+/// A language model provider for a user-configured OpenAI-compatible endpoint.
+///
+/// `id` and `name` are both derived from the provider id handed to `new`.
+pub struct OpenAiCompatibleLanguageModelProvider {
+ id: LanguageModelProviderId,
+ name: LanguageModelProviderName,
+ http_client: Arc,
+ // Shared with every model created by this provider and with the configuration view.
+ state: gpui::Entity,
+}
+
+/// Credential and settings state of one OpenAI-compatible provider.
+pub struct State {
+ // Provider id; used to look up this provider's entry in the `openai_compatible` settings.
+ id: Arc,
+ // Name of the environment variable checked for an API key before stored credentials.
+ env_var_name: Arc,
+ // `None` until a key has been loaded or set.
+ api_key: Option,
+ // Whether `api_key` was read from the environment variable by `authenticate`.
+ api_key_from_env: bool,
+ // Last settings snapshot observed for this provider.
+ settings: OpenAiCompatibleSettings,
+ // Held so the subscription stored at construction lives as long as the state.
+ _subscription: Subscription,
+}
+
+/// Credential handling for a single OpenAI-compatible provider.
+impl State {
+ /// Returns `true` once an API key has been loaded or set.
+ fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
+ /// Deletes the credentials stored under the configured API URL and clears the in-memory key.
+ ///
+ /// A failed deletion is only logged; the in-memory key is cleared regardless.
+ fn reset_api_key(&self, cx: &mut Context) -> Task> {
+ let credentials_provider = ::global(cx);
+ let api_url = self.settings.api_url.clone();
+ cx.spawn(async move |this, cx| {
+ credentials_provider
+ .delete_credentials(&api_url, &cx)
+ .await
+ .log_err();
+ this.update(cx, |this, cx| {
+ this.api_key = None;
+ this.api_key_from_env = false;
+ cx.notify();
+ })
+ })
+ }
+
+ /// Stores `api_key` under the configured API URL and keeps it in memory.
+ ///
+ /// A failed write is only logged; the in-memory key is set regardless.
+ fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> {
+ let credentials_provider = ::global(cx);
+ let api_url = self.settings.api_url.clone();
+ cx.spawn(async move |this, cx| {
+ credentials_provider
+ .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
+ .await
+ .log_err();
+ this.update(cx, |this, cx| {
+ this.api_key = Some(api_key);
+ cx.notify();
+ })
+ })
+ }
+
+ /// Loads the API key, preferring the environment variable over stored credentials.
+ ///
+ /// Returns immediately when a key is already present. Fails with
+ /// `AuthenticateError::CredentialsNotFound` when neither source has a key, and with a
+ /// contextual error when the stored key is not valid UTF-8.
+ fn authenticate(&self, cx: &mut Context) -> Task> {
+ if self.is_authenticated() {
+ return Task::ready(Ok(()));
+ }
+
+ let credentials_provider = ::global(cx);
+ let env_var_name = self.env_var_name.clone();
+ let api_url = self.settings.api_url.clone();
+ // Captured so the error below can name the provider whose stored key is malformed.
+ let provider_id = self.id.clone();
+ cx.spawn(async move |this, cx| {
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
+ (api_key, true)
+ } else {
+ let (_, api_key) = credentials_provider
+ .read_credentials(&api_url, &cx)
+ .await?
+ .ok_or(AuthenticateError::CredentialsNotFound)?;
+ (
+ // `context` uses its message verbatim, so a "{...}" placeholder in a plain
+ // literal is never interpolated; build the message with `format!` instead.
+ String::from_utf8(api_key)
+ .with_context(|| format!("invalid API key for provider {provider_id}"))?,
+ false,
+ )
+ };
+ this.update(cx, |this, cx| {
+ this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
+ cx.notify();
+ })?;
+
+ Ok(())
+ })
+ }
+}
+
+impl OpenAiCompatibleLanguageModelProvider {
+ /// Creates the provider for the settings entry keyed by `id`.
+ ///
+ /// The initial settings come from the `openai_compatible` map (or the default value when
+ /// there is no entry), and the API key environment variable name is `{id}_API_KEY`
+ /// converted to constant case.
+ pub fn new(id: Arc, http_client: Arc, cx: &mut App) -> Self {
+ // Looks up this provider's entry in the global language model settings.
+ fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
+ AllLanguageModelSettings::get_global(cx)
+ .openai_compatible
+ .get(id)
+ }
+
+ let state = cx.new(|cx| State {
+ id: id.clone(),
+ env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
+ settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
+ api_key: None,
+ api_key_from_env: false,
+ // NOTE(review): the observed global's type is not visible here; presumably the
+ // settings store — confirm.
+ _subscription: cx.observe_global::(|this: &mut State, cx| {
+ // When the entry has disappeared, the previous snapshot is kept as is.
+ let Some(settings) = resolve_settings(&this.id, cx) else {
+ return;
+ };
+ // Only notify observers when the settings actually changed.
+ if &this.settings != settings {
+ this.settings = settings.clone();
+ cx.notify();
+ }
+ }),
+ });
+
+ Self {
+ id: id.clone().into(),
+ name: id.into(),
+ http_client,
+ state,
+ }
+ }
+
+ /// Wraps `model` in an `OpenAiCompatibleLanguageModel` that shares this provider's state
+ /// and HTTP client.
+ fn create_language_model(&self, model: AvailableModel) -> Arc {
+ Arc::new(OpenAiCompatibleLanguageModel {
+ id: LanguageModelId::from(model.name.clone()),
+ provider_id: self.id.clone(),
+ provider_name: self.name.clone(),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ // NOTE(review): presumably a limit of 4 concurrent requests per model — confirm.
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+}
+
+/// Exposes the provider's `State` entity as its observable entity.
+impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
+ type ObservableEntity = State;
+
+ /// Always returns a handle to this provider's state.
+ fn observable_entity(&self) -> Option> {
+ Some(self.state.clone())
+ }
+}
+
+/// `LanguageModelProvider` implementation backed by the provider's settings snapshot.
+impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ self.name.clone()
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiOpenAiCompat
+ }
+
+ /// Returns the first configured model, or `None` when no models are configured.
+ fn default_model(&self, cx: &App) -> Option> {
+ self.state
+ .read(cx)
+ .settings
+ .available_models
+ .first()
+ .map(|model| self.create_language_model(model.clone()))
+ }
+
+ /// This provider never designates a fast model.
+ fn default_fast_model(&self, _cx: &App) -> Option> {
+ None
+ }
+
+ /// Returns one language model per configured `available_models` entry, in settings order.
+ fn provided_models(&self, cx: &App) -> Vec> {
+ self.state
+ .read(cx)
+ .settings
+ .available_models
+ .iter()
+ .map(|model| self.create_language_model(model.clone()))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ /// Delegates to `State::authenticate`.
+ fn authenticate(&self, cx: &mut App) -> Task> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ /// Creates a new `ConfigurationView` bound to this provider's state.
+ fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ /// Delegates to `State::reset_api_key`.
+ fn reset_credentials(&self, cx: &mut App) -> Task> {
+ self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ }
+}
+
+/// A single model served by an OpenAI-compatible provider.
+pub struct OpenAiCompatibleLanguageModel {
+ id: LanguageModelId,
+ provider_id: LanguageModelProviderId,
+ provider_name: LanguageModelProviderName,
+ // Settings entry this model was created from.
+ model: AvailableModel,
+ // Provider state; read for the API key and API URL on every completion request.
+ state: gpui::Entity,
+ http_client: Arc,
+ request_limiter: RateLimiter,
+}
+
+impl OpenAiCompatibleLanguageModel {
+ /// Sends `request` to the provider's API URL and returns the raw response event stream.
+ ///
+ /// The API key and URL are read from the provider state when this is called. The returned
+ /// future fails with `NoApiKey` when no key is present, and with "App state dropped" when
+ /// the state entity can no longer be read.
+ fn stream_completion(
+ &self,
+ request: open_ai::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<'static, Result>>>
+ {
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
+ (state.api_key.clone(), state.settings.api_url.clone())
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ let provider = self.provider_name.clone();
+ // The request is issued through the model's rate limiter.
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+/// `LanguageModel` implementation that speaks the OpenAI chat completion protocol.
+impl LanguageModel for OpenAiCompatibleLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ /// Returns the configured display name, falling back to the model name.
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(
+ self.model
+ .display_name
+ .clone()
+ .unwrap_or_else(|| self.model.name.clone()),
+ )
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ self.provider_id.clone()
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ self.provider_name.clone()
+ }
+
+ /// Tool use is always reported as supported, regardless of the configured model.
+ fn supports_tools(&self) -> bool {
+ true
+ }
+
+ /// Image input is always reported as unsupported.
+ fn supports_images(&self) -> bool {
+ false
+ }
+
+ /// Every tool choice variant is reported as supported.
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto => true,
+ LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => true,
+ }
+ }
+
+ // NOTE(review): the prefix is the fixed string "openai", not the provider id — confirm
+ // that this is intended for telemetry.
+ fn telemetry_id(&self) -> String {
+ format!("openai/{}", self.model.name)
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_tokens
+ }
+
+ fn max_output_tokens(&self) -> Option {
+ self.model.max_output_tokens
+ }
+
+ /// Estimates the token count of `request` on a background task using `tiktoken_rs`.
+ ///
+ /// The tokenizer is chosen heuristically from the model's `max_tokens`, because the
+ /// actual tokenizer of a compatible model is unknown.
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result> {
+ let max_token_count = self.max_token_count();
+ cx.background_spawn(async move {
+ let messages = super::open_ai::collect_tiktoken_messages(request);
+ let model = if max_token_count >= 100_000 {
+ // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
+ "gpt-4o"
+ } else {
+ // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
+ // supported with this tiktoken method
+ "gpt-4"
+ };
+ tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
+ })
+ .boxed()
+ }
+
+ /// Converts `request` with `into_open_ai`, streams it through the inherent
+ /// `stream_completion`, and maps the response events with `OpenAiEventMapper`.
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
+ let completions = self.stream_completion(request, cx);
+ async move {
+ let mapper = OpenAiEventMapper::new();
+ Ok(mapper.map_stream(completions.await?).boxed())
+ }
+ .boxed()
+ }
+}
+
+/// View for entering or resetting the API key of an OpenAI-compatible provider.
+struct ConfigurationView {
+ api_key_editor: Entity,
+ state: gpui::Entity,
+ // `Some` while the initial credential load is running; `render` shows a loading label.
+ load_credentials_task: Option>,
+}
+
+impl ConfigurationView {
+ /// Builds the view and starts loading credentials through `State::authenticate`.
+ fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ SingleLineInput::new(
+ window,
+ cx,
+ "000000000000000000000000000000000000000000000000000",
+ )
+ });
+
+ // Re-render whenever the provider state notifies.
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+ // We don't log an error, because "not signed in" is also an error.
+ let _ = task.await;
+ }
+ // Clearing the task switches `render` from the loading label to the real UI.
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ /// Handles `menu::Confirm`: stores the trimmed editor text as the API key.
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) {
+ let api_key = self
+ .api_key_editor
+ .read(cx)
+ .editor()
+ .read(cx)
+ .text(cx)
+ .trim()
+ .to_string();
+
+ // Don't proceed if no API key is provided and we're not authenticated
+ if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
+ return;
+ }
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+ /// Clears the editor text and resets the API key through `State::reset_api_key`.
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) {
+ self.api_key_editor.update(cx, |input, cx| {
+ input.editor.update(cx, |editor, cx| {
+ editor.set_text("", window, cx);
+ });
+ });
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+ /// The key editor is shown only while no API key is present.
+ fn should_render_editor(&self, cx: &mut Context) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+/// Renders a loading label, the API key editor, or the "key configured" row with a reset
+/// button, depending on the credential state.
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_name = self.state.read(cx).env_var_name.clone();
+
+ // Not authenticated: show the key editor; authenticated: show the status row.
+ let api_key_section = if self.should_render_editor(cx) {
+ v_flex()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new("To use Zed's assistant with an OpenAI compatible provider, you need to add an API key."))
+ .child(
+ div()
+ .pt(DynamicSpacing::Base04.rems(cx))
+ .child(self.api_key_editor.clone())
+ )
+ .child(
+ Label::new(
+ format!("You can also assign the {env_var_name} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .into_any()
+ } else {
+ h_flex()
+ .mt_1()
+ .p_1()
+ .justify_between()
+ .rounded_md()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().background)
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new(if env_var_set {
+ format!("API key set in {env_var_name} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
+ )
+ .child(
+ Button::new("reset-api-key", "Reset API Key")
+ .label_size(LabelSize::Small)
+ .icon(IconName::Undo)
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .layer(ElevationIndex::ModalSurface)
+ // The tooltip is only attached when the key came from the environment.
+ .when(env_var_set, |this| {
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
+ )
+ .into_any()
+ };
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials…")).into_any()
+ } else {
+ v_flex().size_full().child(api_key_section).into_any()
+ }
+ }
+}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index dafbb62910..b163585aa7 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -1,4 +1,7 @@
+use std::sync::Arc;
+
use anyhow::Result;
+use collections::HashMap;
use gpui::App;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -15,13 +18,14 @@ use crate::provider::{
mistral::MistralSettings,
ollama::OllamaSettings,
open_ai::OpenAiSettings,
+ open_ai_compatible::OpenAiCompatibleSettings,
open_router::OpenRouterSettings,
vercel::VercelSettings,
x_ai::XAiSettings,
};
/// Initializes the language model settings.
-pub fn init(cx: &mut App) {
+pub fn init_settings(cx: &mut App) {
AllLanguageModelSettings::register(cx);
}
@@ -36,6 +40,7 @@ pub struct AllLanguageModelSettings {
pub ollama: OllamaSettings,
pub open_router: OpenRouterSettings,
pub openai: OpenAiSettings,
+ pub openai_compatible: HashMap, OpenAiCompatibleSettings>,
pub vercel: VercelSettings,
pub x_ai: XAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
@@ -52,6 +57,7 @@ pub struct AllLanguageModelSettingsContent {
pub ollama: Option,
pub open_router: Option,
pub openai: Option,
+ pub openai_compatible: Option, OpenAiCompatibleSettingsContent>>,
pub vercel: Option,
pub x_ai: Option,
#[serde(rename = "zed.dev")]
@@ -103,6 +109,12 @@ pub struct OpenAiSettingsContent {
pub available_models: Option>,
}
+/// Serialized form of one entry of the `openai_compatible` settings map.
+///
+/// Both fields are required; they are copied into `OpenAiCompatibleSettings` when settings
+/// are loaded.
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OpenAiCompatibleSettingsContent {
+ pub api_url: String,
+ pub available_models: Vec,
+}
+
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct VercelSettingsContent {
pub api_url: Option,
@@ -226,6 +238,19 @@ impl settings::Settings for AllLanguageModelSettings {
openai.as_ref().and_then(|s| s.available_models.clone()),
);
+ // OpenAI Compatible
+ if let Some(openai_compatible) = value.openai_compatible.clone() {
+ for (id, openai_compatible_settings) in openai_compatible {
+ settings.openai_compatible.insert(
+ id,
+ OpenAiCompatibleSettings {
+ api_url: openai_compatible_settings.api_url,
+ available_models: openai_compatible_settings.available_models,
+ },
+ );
+ }
+ }
+
// Vercel
let vercel = value.vercel.clone();
merge(
diff --git a/crates/ui/src/components/modal.rs b/crates/ui/src/components/modal.rs
index 2e926b7593..2145b34ef2 100644
--- a/crates/ui/src/components/modal.rs
+++ b/crates/ui/src/components/modal.rs
@@ -93,6 +93,7 @@ impl RenderOnce for Modal {
#[derive(IntoElement)]
pub struct ModalHeader {
headline: Option,
+ description: Option,
children: SmallVec<[AnyElement; 2]>,
show_dismiss_button: bool,
show_back_button: bool,
@@ -108,6 +109,7 @@ impl ModalHeader {
pub fn new() -> Self {
Self {
headline: None,
+ description: None,
children: SmallVec::new(),
show_dismiss_button: false,
show_back_button: false,
@@ -123,6 +125,11 @@ impl ModalHeader {
self
}
+ /// Sets a description, rendered as a muted label below the header's children.
+ pub fn description(mut self, description: impl Into) -> Self {
+ self.description = Some(description.into());
+ self
+ }
+
pub fn show_dismiss_button(mut self, show: bool) -> Self {
self.show_dismiss_button = show;
self
@@ -171,7 +178,14 @@ impl RenderOnce for ModalHeader {
}),
)
})
- .child(div().flex_1().children(children))
+ .child(
+ v_flex().flex_1().children(children).when_some(
+ self.description,
+ |this, description| {
+ this.child(Label::new(description).color(Color::Muted).mb_2())
+ },
+ ),
+ )
.when(self.show_dismiss_button, |this| {
this.child(
IconButton::new("dismiss", IconName::Close)
diff --git a/crates/ui_input/src/ui_input.rs b/crates/ui_input/src/ui_input.rs
index 18aa732e81..309b3f62f6 100644
--- a/crates/ui_input/src/ui_input.rs
+++ b/crates/ui_input/src/ui_input.rs
@@ -97,6 +97,10 @@ impl SingleLineInput {
pub fn editor(&self) -> &Entity {
&self.editor
}
+
+ /// Returns the current text of the underlying editor.
+ pub fn text(&self, cx: &App) -> String {
+ self.editor().read(cx).text(cx)
+ }
}
impl Render for SingleLineInput {
diff --git a/docs/src/ai/configuration.md b/docs/src/ai/configuration.md
index 1201fa2173..414da2206f 100644
--- a/docs/src/ai/configuration.md
+++ b/docs/src/ai/configuration.md
@@ -444,14 +444,17 @@ Custom models will be listed in the model dropdown in the Agent Panel.
### OpenAI API Compatible {#openai-api-compatible}
-Zed supports using OpenAI compatible APIs by specifying a custom `endpoint` and `available_models` for the OpenAI provider.
+Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
-Zed supports using OpenAI compatible APIs by specifying a custom `api_url` and `available_models` for the OpenAI provider. This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
+To configure a compatible API, you can add a custom API URL for OpenAI either via the UI (currently available only in Preview) or by editing your `settings.json`.
-To configure a compatible API, you can add a custom API URL for OpenAI either via the UI or by editing your `settings.json`. For example, to connect to [Together AI](https://www.together.ai/):
+For example, to connect to [Together AI](https://www.together.ai/) via the UI:
-1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys).
-2. Add the following to your `settings.json`:
+1. Get an API key from your [Together AI account](https://api.together.ai/settings/api-keys).
+2. Go to the Agent Panel's settings view, click on the "Add Provider" button, and then on the "OpenAI" menu item.
+3. Fill in the requested fields, such as `api_url`, `api_key`, available models, and others.
+
+Alternatively, you can also add it via the `settings.json`:
```json
{