From 43cb925a599ead7ddbeee26efd1f5af1d1dea97a Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 4 Apr 2025 11:40:55 -0300 Subject: [PATCH] ai: Separate model settings for each feature (#28088) Closes: https://github.com/zed-industries/zed/issues/20582 Allows users to select a specific model for each AI-powered feature: - Agent panel - Inline assistant - Thread summarization - Commit message generation If unspecified for a given feature, it will use the `default_model` setting. Release Notes: - Added support for configuring a specific model for each AI-powered feature --------- Co-authored-by: Danilo Leal Co-authored-by: Bennet Bo Fenner --- crates/agent/src/active_thread.rs | 22 +- .../assistant_configuration/tool_picker.rs | 72 ++-- crates/agent/src/assistant_model_selector.rs | 52 ++- crates/agent/src/assistant_panel.rs | 16 +- crates/agent/src/buffer_codegen.rs | 5 +- crates/agent/src/inline_assistant.rs | 22 +- crates/agent/src/inline_prompt_editor.rs | 6 +- crates/agent/src/message_editor.rs | 22 +- crates/agent/src/profile_selector.rs | 4 +- crates/agent/src/terminal_codegen.rs | 8 +- crates/agent/src/terminal_inline_assistant.rs | 8 +- crates/agent/src/thread.rs | 25 +- crates/assistant/src/assistant.rs | 56 +++- crates/assistant/src/assistant_panel.rs | 33 +- crates/assistant/src/inline_assistant.rs | 31 +- .../src/terminal_inline_assistant.rs | 24 +- .../assistant_context_editor/src/context.rs | 25 +- .../src/context_editor.rs | 36 +- .../assistant_eval/src/headless_assistant.rs | 4 +- crates/assistant_eval/src/main.rs | 16 +- .../src/assistant_settings.rs | 310 +++++++++++------- crates/assistant_tools/src/edit_files_tool.rs | 6 +- crates/git_ui/src/git_panel.rs | 8 +- crates/language_model/src/registry.rs | 151 ++++++--- .../src/language_model_selector.rs | 13 +- crates/prompt_library/src/prompt_library.rs | 14 +- docs/src/assistant/configuration.md | 62 +++- 27 files changed, 670 insertions(+), 381 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index dd2225ce71..f645fdfdcc 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -21,7 +21,7 @@ use gpui::{ linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; -use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; +use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role}; use markdown::{Markdown, MarkdownStyle}; use project::ProjectItem as _; use settings::{Settings as _, update_settings_file}; @@ -606,7 +606,7 @@ impl ActiveThread { if self.thread.read(cx).all_tools_finished() { let model_registry = LanguageModelRegistry::read_global(cx); - if let Some(model) = model_registry.active_model() { + if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() { self.thread.update(cx, |thread, cx| { thread.attach_tool_results(cx); if !canceled { @@ -814,21 +814,17 @@ impl ActiveThread { } }); - let provider = LanguageModelRegistry::read_global(cx).active_provider(); - if provider - .as_ref() - .map_or(false, |provider| provider.must_accept_terms(cx)) - { - cx.notify(); - return; - } - let model_registry = LanguageModelRegistry::read_global(cx); - let Some(model) = model_registry.active_model() else { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { return; }; + if model.provider.must_accept_terms(cx) { + cx.notify(); + return; + } + self.thread.update(cx, |thread, cx| { - thread.send_to_model(model, RequestKind::Chat, cx) + thread.send_to_model(model.model, RequestKind::Chat, cx) }); cx.notify(); } diff --git a/crates/agent/src/assistant_configuration/tool_picker.rs b/crates/agent/src/assistant_configuration/tool_picker.rs index 7ca6747a8e..eabd9e172b 100644 --- a/crates/agent/src/assistant_configuration/tool_picker.rs +++ b/crates/agent/src/assistant_configuration/tool_picker.rs @@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate { let default_profile = self.profile.clone(); let tool = tool.clone(); move |settings, _cx| match settings { - AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2( - settings, - )) => { - let profiles = settings.profiles.get_or_insert_default(); - let profile = - profiles - .entry(profile_id) - .or_insert_with(|| AgentProfileContent { - name: default_profile.name.into(), - tools: default_profile.tools, - enable_all_context_servers: Some( - default_profile.enable_all_context_servers, - ), - context_servers: default_profile - .context_servers - .into_iter() - .map(|(server_id, preset)| { - ( - server_id, - ContextServerPresetContent { - tools: preset.tools, - }, - ) - }) - .collect(), - }); + AssistantSettingsContent::Versioned(boxed) => { + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + let profiles = settings.profiles.get_or_insert_default(); + let profile = + profiles + .entry(profile_id) + .or_insert_with(|| AgentProfileContent { + name: default_profile.name.into(), + tools: default_profile.tools, + enable_all_context_servers: Some( + default_profile.enable_all_context_servers, + ), + context_servers: default_profile + .context_servers + .into_iter() + .map(|(server_id, preset)| { + ( + server_id, + ContextServerPresetContent { + tools: preset.tools, + }, + ) + }) + .collect(), + }); - match tool.source { - ToolSource::Native => { - *profile.tools.entry(tool.name).or_default() = is_enabled; - } - ToolSource::ContextServer { id } => { - let preset = profile - .context_servers - .entry(id.clone().into()) - .or_default(); - *preset.tools.entry(tool.name.clone()).or_default() = is_enabled; + match tool.source { + ToolSource::Native => { + *profile.tools.entry(tool.name).or_default() = is_enabled; + } + ToolSource::ContextServer { id } => { + let preset = profile + .context_servers + .entry(id.clone().into()) + .or_default(); + *preset.tools.entry(tool.name.clone()).or_default() = is_enabled; + } } } } diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs index a8da7ca16e..11726b2574 100644 --- a/crates/agent/src/assistant_model_selector.rs +++ b/crates/agent/src/assistant_model_selector.rs @@ -9,10 +9,17 @@ use settings::update_settings_file; use std::sync::Arc; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; +#[derive(Clone, Copy)] +pub enum ModelType { + Default, + InlineAssistant, +} + pub struct AssistantModelSelector { selector: Entity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, + model_type: ModelType, } impl AssistantModelSelector { @@ -20,6 +27,7 @@ impl AssistantModelSelector { fs: Arc, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, + model_type: ModelType, window: &mut Window, cx: &mut App, ) -> Self { @@ -28,11 +36,32 @@ impl AssistantModelSelector { let fs = fs.clone(); LanguageModelSelector::new( move |model, cx| { - update_settings_file::( - fs.clone(), - cx, - move |settings, _cx| settings.set_model(model.clone()), - ); + let provider = model.provider_id().0.to_string(); + let model_id = model.id().0.to_string(); + + match model_type { + ModelType::Default => { + update_settings_file::( + fs.clone(), + cx, + move |settings, _cx| { + settings.set_model(model.clone()); + }, + ); + } + ModelType::InlineAssistant => { + update_settings_file::( + fs.clone(), + cx, + move |settings, _cx| { + settings.set_inline_assistant_model( + provider.clone(), + model_id.clone(), + ); + }, + ); + } + } }, window, cx, @@ -40,6 +69,7 @@ impl AssistantModelSelector { }), menu_handle, focus_handle, + model_type, } } @@ -50,10 +80,16 @@ impl AssistantModelSelector { impl Render for AssistantModelSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let active_model = LanguageModelRegistry::read_global(cx).active_model(); + let model_registry = LanguageModelRegistry::read_global(cx); + + let model = match self.model_type { + ModelType::Default => model_registry.default_model(), + ModelType::InlineAssistant => model_registry.inline_assistant_model(), + }; + let focus_handle = self.focus_handle.clone(); - let model_name = match active_model { - Some(model) => model.name().0, + let model_name = match model { + Some(model) => model.model.name().0, _ => SharedString::from("No model selected"), }; diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index e11ac31586..23417a3e29 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -571,10 +571,8 @@ impl AssistantPanel { match event { AssistantConfigurationEvent::NewThread(provider) => { if LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(true, |active_provider| { - active_provider.id() != provider.id() - }) + .default_model() + .map_or(true, |model| model.provider.id() != provider.id()) { if let Some(model) = provider.default_model(cx) { update_settings_file::( @@ -922,16 +920,18 @@ impl AssistantPanel { } fn configuration_error(&self, cx: &App) -> Option { - let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { return Some(ConfigurationError::NoProvider); }; - if !provider.is_authenticated(cx) { + if !model.provider.is_authenticated(cx) { return Some(ConfigurationError::ProviderNotAuthenticated); } - if provider.must_accept_terms(cx) { - return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)); + if model.provider.must_accept_terms(cx) { + return Some(ConfigurationError::ProviderPendingTermsAcceptance( + model.provider, + )); } None diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index e1ef0ecac8..42a42ffa0f 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -156,8 +156,9 @@ impl BufferCodegen { } let primary_model = LanguageModelRegistry::read_global(cx) - .active_model() - .context("no active model")?; + .default_model() + .context("no active model")? + .model; for (model, alternative) in iter::once(primary_model) .chain(alternative_models) diff --git a/crates/agent/src/inline_assistant.rs b/crates/agent/src/inline_assistant.rs index 33a7d1f891..b0356f9048 100644 --- a/crates/agent/src/inline_assistant.rs +++ b/crates/agent/src/inline_assistant.rs @@ -239,8 +239,8 @@ impl InlineAssistant { let is_authenticated = || { LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(false, |provider| provider.is_authenticated(cx)) + .inline_assistant_model() + .map_or(false, |model| model.provider.is_authenticated(cx)) }; let thread_store = workspace @@ -279,8 +279,8 @@ impl InlineAssistant { cx.spawn_in(window, async move |_workspace, cx| { let Some(task) = cx.update(|_, cx| { LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(None, |provider| Some(provider.authenticate(cx))) + .inline_assistant_model() + .map_or(None, |model| Some(model.provider.authenticate(cx))) })? else { let answer = cx @@ -401,14 +401,14 @@ impl InlineAssistant { codegen_ranges.push(anchor_range); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() { self.telemetry.report_assistant_event(AssistantEvent { conversation_id: None, kind: AssistantKind::Inline, phase: AssistantPhase::Invoked, message_id: None, - model: model.telemetry_id(), - model_provider: model.provider_id().to_string(), + model: model.model.telemetry_id(), + model_provider: model.provider.id().to_string(), response_latency: None, error_message: None, language_name: buffer.language().map(|language| language.name().to_proto()), @@ -976,7 +976,7 @@ impl InlineAssistant { let active_alternative = assist.codegen.read(cx).active_alternative().clone(); let message_id = active_alternative.read(cx).message_id.clone(); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() { let language_name = assist.editor.upgrade().and_then(|editor| { let multibuffer = editor.read(cx).buffer().read(cx); let snapshot = multibuffer.snapshot(cx); @@ -996,15 +996,15 @@ impl InlineAssistant { } else { AssistantPhase::Accepted }, - model: model.telemetry_id(), - model_provider: model.provider_id().to_string(), + model: model.model.telemetry_id(), + model_provider: model.model.provider_id().to_string(), response_latency: None, error_message: None, language_name: language_name.map(|name| name.to_proto()), }, Some(self.telemetry.clone()), cx.http_client(), - model.api_key(cx), + model.model.api_key(cx), cx.background_executor(), ); } diff --git a/crates/agent/src/inline_prompt_editor.rs b/crates/agent/src/inline_prompt_editor.rs index baa72bd434..913368d0e4 100644 --- a/crates/agent/src/inline_prompt_editor.rs +++ b/crates/agent/src/inline_prompt_editor.rs @@ -1,4 +1,4 @@ -use crate::assistant_model_selector::AssistantModelSelector; +use crate::assistant_model_selector::{AssistantModelSelector, ModelType}; use crate::buffer_codegen::BufferCodegen; use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; @@ -582,7 +582,7 @@ impl PromptEditor { let disabled = matches!(codegen.status(cx), CodegenStatus::Idle); let model_registry = LanguageModelRegistry::read_global(cx); - let default_model = model_registry.active_model(); + let default_model = model_registry.default_model().map(|default| default.model); let alternative_models = model_registry.inline_alternative_models(); let get_model_name = |index: usize| -> String { @@ -890,6 +890,7 @@ impl PromptEditor { fs, model_selector_menu_handle, prompt_editor.focus_handle(cx), + ModelType::InlineAssistant, window, cx, ) @@ -1042,6 +1043,7 @@ impl PromptEditor { fs, model_selector_menu_handle.clone(), prompt_editor.focus_handle(cx), + ModelType::InlineAssistant, window, cx, ) diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 7cc5b44026..3cc00c39fb 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use crate::assistant_model_selector::ModelType; use collections::HashSet; use editor::actions::MoveUp; use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle}; @@ -10,7 +11,7 @@ use gpui::{ WeakEntity, linear_color_stop, linear_gradient, point, }; use language::Buffer; -use language_model::LanguageModelRegistry; +use language_model::{ConfiguredModel, LanguageModelRegistry}; use language_model_selector::ToggleModelSelector; use multi_buffer; use project::Project; @@ -139,6 +140,7 @@ impl MessageEditor { fs.clone(), model_selector_menu_handle, editor.focus_handle(cx), + ModelType::Default, window, cx, ) @@ -191,7 +193,7 @@ impl MessageEditor { fn is_model_selected(&self, cx: &App) -> bool { LanguageModelRegistry::read_global(cx) - .active_model() + .default_model() .is_some() } @@ -201,20 +203,16 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { - let provider = LanguageModelRegistry::read_global(cx).active_provider(); - if provider - .as_ref() - .map_or(false, |provider| provider.must_accept_terms(cx)) - { + let model_registry = LanguageModelRegistry::read_global(cx); + let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else { + return; + }; + + if provider.must_accept_terms(cx) { cx.notify(); return; } - let model_registry = LanguageModelRegistry::read_global(cx); - let Some(model) = model_registry.active_model() else { - return; - }; - let user_message = self.editor.update(cx, |editor, cx| { let text = editor.text(cx); editor.clear(window, cx); diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index 46a8cf4273..dfcafba5fc 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -130,8 +130,8 @@ impl Render for ProfileSelector { let model_registry = LanguageModelRegistry::read_global(cx); let supports_tools = model_registry - .active_model() - .map_or(false, |model| model.supports_tools()); + .default_model() + .map_or(false, |default| default.model.supports_tools()); let icon = match profile_id.as_str() { "write" => IconName::Pencil, diff --git a/crates/agent/src/terminal_codegen.rs b/crates/agent/src/terminal_codegen.rs index a791de76b7..29f4329523 100644 --- a/crates/agent/src/terminal_codegen.rs +++ b/crates/agent/src/terminal_codegen.rs @@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus; use client::telemetry::Telemetry; use futures::{SinkExt, StreamExt, channel::mpsc}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task}; -use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event}; +use language_model::{ + ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event, +}; use std::{sync::Arc, time::Instant}; use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase}; use terminal::Terminal; @@ -31,7 +33,9 @@ impl TerminalCodegen { } pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context) { - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + else { return; }; diff --git a/crates/agent/src/terminal_inline_assistant.rs b/crates/agent/src/terminal_inline_assistant.rs index 15460609bb..b68ace831f 100644 --- a/crates/agent/src/terminal_inline_assistant.rs +++ b/crates/agent/src/terminal_inline_assistant.rs @@ -13,8 +13,8 @@ use fs::Fs; use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity}; use language::Buffer; use language_model::{ - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, - report_assistant_event, + ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + Role, report_assistant_event, }; use prompt_store::PromptBuilder; use std::sync::Arc; @@ -286,7 +286,9 @@ impl TerminalInlineAssistant { }) .log_err(); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + { let codegen = assist.codegen.read(cx); let executor = cx.background_executor().clone(); report_assistant_event( diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 611c93e9b4..45c4c0c895 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -14,10 +14,10 @@ use futures::{FutureExt, StreamExt as _}; use git::repository::DiffType; use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, - Role, StopReason, TokenUsage, + ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, + PaymentRequiredError, Role, StopReason, TokenUsage, }; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; use project::{Project, Worktree}; @@ -1250,14 +1250,11 @@ impl Thread { } pub fn summarize(&mut self, cx: &mut Context) { - let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { - return; - }; - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { return; }; - if !provider.is_authenticated(cx) { + if !model.provider.is_authenticated(cx) { return; } @@ -1276,7 +1273,7 @@ impl Thread { self.pending_summary = cx.spawn(async move |this, cx| { async move { - let stream = model.stream_completion_text(request, &cx); + let stream = model.model.stream_completion_text(request, &cx); let mut messages = stream.await?; let mut new_summary = String::new(); @@ -1320,8 +1317,8 @@ impl Thread { _ => {} } - let provider = LanguageModelRegistry::read_global(cx).active_provider()?; - let model = LanguageModelRegistry::read_global(cx).active_model()?; + let ConfiguredModel { model, provider } = + LanguageModelRegistry::read_global(cx).thread_summary_model()?; if !provider.is_authenticated(cx) { return None; @@ -1782,11 +1779,11 @@ impl Thread { pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); - let Some(model) = model_registry.active_model() else { + let Some(model) = model_registry.default_model() else { return TotalTokenUsage::default(); }; - let max = model.max_token_count(); + let max = model.model.max_token_count(); #[cfg(debug_assertions)] let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 093713d2e7..16b775865d 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -161,12 +161,38 @@ fn init_language_model_settings(cx: &mut App) { fn update_active_language_model_from_settings(cx: &mut App) { let settings = AssistantSettings::get_global(cx); + // Default model - used as fallback let active_model_provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone()); let active_model_id = LanguageModelId::from(settings.default_model.model.clone()); - let editor_provider_name = - LanguageModelProviderId::from(settings.editor_model.provider.clone()); - let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone()); + + // Inline assistant model + let inline_assistant_model = settings + .inline_assistant_model + .as_ref() + .unwrap_or(&settings.default_model); + let inline_assistant_provider_name = + LanguageModelProviderId::from(inline_assistant_model.provider.clone()); + let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone()); + + // Commit message model + let commit_message_model = settings + .commit_message_model + .as_ref() + .unwrap_or(&settings.default_model); + let commit_message_provider_name = + LanguageModelProviderId::from(commit_message_model.provider.clone()); + let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone()); + + // Thread summary model + let thread_summary_model = settings + .thread_summary_model + .as_ref() + .unwrap_or(&settings.default_model); + let thread_summary_provider_name = + LanguageModelProviderId::from(thread_summary_model.provider.clone()); + let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone()); + let inline_alternatives = settings .inline_alternatives .iter() @@ -177,9 +203,29 @@ fn update_active_language_model_from_settings(cx: &mut App) { ) }) .collect::>(); + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.select_active_model(&active_model_provider_name, &active_model_id, cx); - registry.select_editor_model(&editor_provider_name, &editor_model_id, cx); + // Set the default model + registry.select_default_model(&active_model_provider_name, &active_model_id, cx); + + // Set the specific models + registry.select_inline_assistant_model( + &inline_assistant_provider_name, + &inline_assistant_model_id, + cx, + ); + registry.select_commit_message_model( + &commit_message_provider_name, + &commit_message_model_id, + cx, + ); + registry.select_thread_summary_model( + &thread_summary_provider_name, + &thread_summary_model_id, + cx, + ); + + // Set the alternatives registry.select_inline_alternative_models(inline_alternatives, cx); }); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 1bb1ac0896..0f5e2c40e7 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -22,7 +22,8 @@ use gpui::{ }; use language::LanguageRegistry; use language_model::{ - AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID, + AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry, + ZED_CLOUD_PROVIDER_ID, }; use project::Project; use prompt_library::{PromptLibrary, open_prompt_library}; @@ -298,8 +299,10 @@ impl AssistantPanel { &LanguageModelRegistry::global(cx), window, |this, _, event: &language_model::Event, window, cx| match event { - language_model::Event::ActiveModelChanged - | language_model::Event::EditorModelChanged => { + language_model::Event::DefaultModelChanged + | language_model::Event::InlineAssistantModelChanged + | language_model::Event::CommitMessageModelChanged + | language_model::Event::ThreadSummaryModelChanged => { this.completion_provider_changed(window, cx); } language_model::Event::ProviderStateChanged => { @@ -468,12 +471,12 @@ impl AssistantPanel { } fn update_zed_ai_notice_visibility(&mut self, client_status: Status, cx: &mut Context) { - let active_provider = LanguageModelRegistry::read_global(cx).active_provider(); + let model = LanguageModelRegistry::read_global(cx).default_model(); // If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is // the provider, we want to show a nudge to sign in. let show_zed_ai_notice = client_status.is_signed_out() - && active_provider.map_or(true, |provider| provider.id().0 == ZED_CLOUD_PROVIDER_ID); + && model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID); self.show_zed_ai_notice = show_zed_ai_notice; cx.notify(); @@ -541,8 +544,8 @@ impl AssistantPanel { } let Some(new_provider_id) = LanguageModelRegistry::read_global(cx) - .active_provider() - .map(|p| p.id()) + .default_model() + .map(|default| default.provider.id()) else { return; }; @@ -568,7 +571,9 @@ impl AssistantPanel { return; } - let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + let Some(ConfiguredModel { provider, .. }) = + LanguageModelRegistry::read_global(cx).default_model() + else { return; }; @@ -976,8 +981,8 @@ impl AssistantPanel { |this, _, event: &ConfigurationViewEvent, window, cx| match event { ConfigurationViewEvent::NewProviderContextEditor(provider) => { if LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(true, |p| p.id() != provider.id()) + .default_model() + .map_or(true, |default| default.provider.id() != provider.id()) { if let Some(model) = provider.default_model(cx) { update_settings_file::( @@ -1155,8 +1160,8 @@ impl AssistantPanel { fn is_authenticated(&mut self, cx: &mut Context) -> bool { LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(false, |provider| provider.is_authenticated(cx)) + .default_model() + .map_or(false, |default| default.provider.is_authenticated(cx)) } fn authenticate( @@ -1164,8 +1169,8 @@ impl AssistantPanel { cx: &mut Context, ) -> Option>> { LanguageModelRegistry::read_global(cx) - .active_provider() - .map_or(None, |provider| Some(provider.authenticate(cx))) + .default_model() + .map_or(None, |default| Some(default.provider.authenticate(cx))) } fn restart_context_servers( diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 0ad67c5923..9a64528a70 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -34,8 +34,8 @@ use gpui::{ }; use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff}; use language_model::{ - LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelTextStream, Role, report_assistant_event, + ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event, }; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use multi_buffer::MultiBufferRow; @@ -312,7 +312,9 @@ impl InlineAssistant { start..end, )); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).default_model() + { self.telemetry.report_assistant_event(AssistantEvent { conversation_id: None, kind: AssistantKind::Inline, @@ -877,7 +879,9 @@ impl InlineAssistant { let active_alternative = assist.codegen.read(cx).active_alternative().clone(); let message_id = active_alternative.read(cx).message_id.clone(); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).default_model() + { let language_name = assist.editor.upgrade().and_then(|editor| { let multibuffer = editor.read(cx).buffer().read(cx); let multibuffer_snapshot = multibuffer.snapshot(cx); @@ -1629,8 +1633,8 @@ impl Render for PromptEditor { format!( "Using {}", LanguageModelRegistry::read_global(cx) - .active_model() - .map(|model| model.name().0) + .default_model() + .map(|default| default.model.name().0) .unwrap_or_else(|| "No model selected".into()), ), None, @@ -2077,7 +2081,7 @@ impl PromptEditor { let disabled = matches!(codegen.status(cx), CodegenStatus::Idle); let model_registry = LanguageModelRegistry::read_global(cx); - let default_model = model_registry.active_model(); + let default_model = model_registry.default_model().map(|default| default.model); let alternative_models = model_registry.inline_alternative_models(); let get_model_name = |index: usize| -> String { @@ -2183,7 +2187,9 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut Context) -> Option { - let model = LanguageModelRegistry::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx) + .default_model()? + .model; let token_counts = self.token_counts?; let max_token_count = model.max_token_count(); @@ -2638,8 +2644,9 @@ impl Codegen { } let primary_model = LanguageModelRegistry::read_global(cx) - .active_model() - .context("no active model")?; + .default_model() + .context("no active model")? + .model; for (model, alternative) in iter::once(primary_model) .chain(alternative_models) @@ -2863,7 +2870,9 @@ impl CodegenAlternative { assistant_panel_context: Option, cx: &App, ) -> BoxFuture<'static, Result> { - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + { let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx); match request { Ok(request) => { diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 3c561cf1d8..8a0599cccb 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -16,8 +16,8 @@ use gpui::{ }; use language::Buffer; use language_model::{ - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, - report_assistant_event, + ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + Role, report_assistant_event, }; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use prompt_store::PromptBuilder; @@ -318,7 +318,9 @@ impl TerminalInlineAssistant { }) .log_err(); - if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { + if let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + { let codegen = assist.codegen.read(cx); let executor = cx.background_executor().clone(); report_assistant_event( @@ -652,8 +654,8 @@ impl Render for PromptEditor { format!( "Using {}", LanguageModelRegistry::read_global(cx) - .active_model() - .map(|model| model.name().0) + .inline_assistant_model() + .map(|inline_assistant| inline_assistant.model.name().0) .unwrap_or_else(|| "No model selected".into()), ), None, @@ -822,7 +824,9 @@ impl PromptEditor { fn count_tokens(&mut self, cx: &mut Context) { let assist_id = self.id; - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + else { return; }; self.pending_token_count = cx.spawn(async move |this, cx| { @@ -980,7 +984,9 @@ impl PromptEditor { } fn render_token_count(&self, cx: &mut Context) -> Option { - let model = LanguageModelRegistry::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx) + .inline_assistant_model()? + .model; let token_count = self.token_count?; let max_token_count = model.max_token_count(); @@ -1131,7 +1137,9 @@ impl Codegen { } pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context) { - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + else { return; }; diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 4b2747d9ad..7196868435 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -1272,7 +1272,7 @@ impl AssistantContext { // Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit), // because otherwise you see in the UI that your empty message has a bunch of tokens already used. let request = self.to_completion_request(RequestType::Chat, cx); - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { return; }; let debounce = self.token_count.is_some(); @@ -1284,10 +1284,12 @@ impl AssistantContext { .await; } - let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?; + let token_count = cx + .update(|cx| model.model.count_tokens(request, cx))? + .await?; this.update(cx, |this, cx| { this.token_count = Some(token_count); - this.start_cache_warming(&model, cx); + this.start_cache_warming(&model.model, cx); cx.notify() }) } @@ -2304,14 +2306,16 @@ impl AssistantContext { cx: &mut Context, ) -> Option { let model_registry = LanguageModelRegistry::read_global(cx); - let provider = model_registry.active_provider()?; - let model = model_registry.active_model()?; + let model = model_registry.default_model()?; let last_message_id = self.get_last_valid_message_id(cx)?; - if !provider.is_authenticated(cx) { + if !model.provider.is_authenticated(cx) { log::info!("completion provider has no credentials"); return None; } + + let model = model.model; + // Compute which messages to cache, including the last one. self.mark_cache_anchors(&model.cache_configuration(), false, cx); @@ -2940,15 +2944,12 @@ impl AssistantContext { } pub fn summarize(&mut self, replace_old: bool, cx: &mut Context) { - let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { - return; - }; - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { return; }; if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { - if !provider.is_authenticated(cx) { + if !model.provider.is_authenticated(cx) { return; } @@ -2964,7 +2965,7 @@ impl AssistantContext { self.pending_summary = cx.spawn(async move |this, cx| { async move { - let stream = model.stream_completion_text(request, &cx); + let stream = model.model.stream_completion_text(request, &cx); let mut messages = stream.await?; let mut replaced = !replace_old; diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index 121ee9345d..1c7396bd92 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -384,7 +384,9 @@ impl ContextEditor { window: &mut Window, cx: &mut Context, ) { - let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let provider = LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default| default.provider); if provider .as_ref() .map_or(false, |provider| provider.must_accept_terms(cx)) @@ -2395,13 +2397,13 @@ impl ContextEditor { None => (ButtonStyle::Filled, None), }; - let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let model = LanguageModelRegistry::read_global(cx).default_model(); let has_configuration_error = configuration_error(cx).is_some(); let needs_to_accept_terms = self.show_accept_terms - && provider + && model .as_ref() - .map_or(false, |provider| provider.must_accept_terms(cx)); + .map_or(false, |model| model.provider.must_accept_terms(cx)); let disabled = has_configuration_error || needs_to_accept_terms; ButtonLike::new("send_button") @@ -2454,7 +2456,9 @@ impl ContextEditor { None => (ButtonStyle::Filled, None), }; - let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let provider = LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default| default.provider); let has_configuration_error = configuration_error(cx).is_some(); let needs_to_accept_terms = self.show_accept_terms @@ -2500,7 +2504,9 @@ impl ContextEditor { } fn render_language_model_selector(&self, cx: &mut Context) -> impl IntoElement { - let active_model = LanguageModelRegistry::read_global(cx).active_model(); + let active_model = LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default| default.model); let focus_handle = self.editor().focus_handle(cx).clone(); let model_name = match active_model { Some(model) => model.name().0, @@ -3020,7 +3026,9 @@ impl EventEmitter for ContextEditor {} impl Render for ContextEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let provider = LanguageModelRegistry::read_global(cx).active_provider(); + let provider = LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default| default.provider); let accept_terms = if self.show_accept_terms { provider.as_ref().and_then(|provider| { provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx) @@ -3616,7 +3624,9 @@ enum TokenState { fn token_state(context: &Entity, cx: &App) -> Option { const WARNING_TOKEN_THRESHOLD: f32 = 0.8; - let model = LanguageModelRegistry::read_global(cx).active_model()?; + let model = LanguageModelRegistry::read_global(cx) + .default_model()? + .model; let token_count = context.read(cx).token_count()?; let max_token_count = model.max_token_count(); @@ -3669,16 +3679,16 @@ pub enum ConfigurationError { } fn configuration_error(cx: &App) -> Option { - let provider = LanguageModelRegistry::read_global(cx).active_provider(); - let is_authenticated = provider + let model = LanguageModelRegistry::read_global(cx).default_model(); + let is_authenticated = model .as_ref() - .map_or(false, |provider| provider.is_authenticated(cx)); + .map_or(false, |model| model.provider.is_authenticated(cx)); - if provider.is_some() && is_authenticated { + if model.is_some() && is_authenticated { return None; } - if provider.is_none() { + if model.is_none() { return Some(ConfigurationError::NoProvider); } diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index 4f82604d48..d86bf253ba 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -156,10 +156,10 @@ impl HeadlessAssistant { } if thread.read(cx).all_tools_finished() { let model_registry = LanguageModelRegistry::read_global(cx); - if let Some(model) = model_registry.active_model() { + if let Some(model) = model_registry.default_model() { thread.update(cx, |thread, cx| { thread.attach_tool_results(cx); - thread.send_to_model(model, RequestKind::Chat, cx); + thread.send_to_model(model.model, RequestKind::Chat, cx); }); } else { println!( diff --git a/crates/assistant_eval/src/main.rs b/crates/assistant_eval/src/main.rs index 99864a1dd3..1c13e5f16c 100644 --- a/crates/assistant_eval/src/main.rs +++ b/crates/assistant_eval/src/main.rs @@ -37,9 +37,6 @@ struct Args { /// Name of the model (default: "claude-3-7-sonnet-latest") #[arg(long, default_value = "claude-3-7-sonnet-latest")] model_name: String, - /// Name of the editor model (default: value of `--model_name`). - #[arg(long)] - editor_model_name: Option, /// Name of the judge model (default: value of `--model_name`). #[arg(long)] judge_model_name: Option, @@ -79,11 +76,6 @@ fn main() { let app_state = headless_assistant::init(cx); let model = find_model(&args.model_name, cx).unwrap(); - let editor_model = if let Some(model_name) = &args.editor_model_name { - find_model(model_name, cx).unwrap() - } else { - model.clone() - }; let judge_model = if let Some(model_name) = &args.judge_model_name { find_model(model_name, cx).unwrap() } else { @@ -91,12 +83,10 @@ fn main() { }; LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_active_model(Some(model.clone()), cx); - registry.set_editor_model(Some(editor_model.clone()), cx); + registry.set_default_model(Some(model.clone()), cx); }); let model_provider_id = model.provider_id(); - let editor_model_provider_id = editor_model.provider_id(); let judge_model_provider_id = judge_model.provider_id(); let framework_path_clone = framework_path.clone(); @@ -110,10 +100,6 @@ fn main() { .unwrap() .await .unwrap(); - cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx)) - .unwrap() - .await - .unwrap(); cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx)) .unwrap() .await diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index e3026fc118..4f7d4e2395 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -77,7 +77,9 @@ pub struct AssistantSettings { pub default_width: Pixels, pub default_height: Pixels, pub default_model: LanguageModelSelection, - pub editor_model: LanguageModelSelection, + pub inline_assistant_model: Option, + pub commit_message_model: Option, + pub thread_summary_model: Option, pub inline_alternatives: Vec, pub using_outdated_settings_version: bool, pub enable_experimental_live_diffs: bool, @@ -95,13 +97,25 @@ impl AssistantSettings { cx.is_staff() || self.enable_experimental_live_diffs } + + pub fn set_inline_assistant_model(&mut self, provider: String, model: String) { + self.inline_assistant_model = Some(LanguageModelSelection { provider, model }); + } + + pub fn set_commit_message_model(&mut self, provider: String, model: String) { + self.commit_message_model = Some(LanguageModelSelection { provider, model }); + } + + pub fn set_thread_summary_model(&mut self, provider: String, model: String) { + self.thread_summary_model = Some(LanguageModelSelection { provider, model }); + } } /// Assistant panel settings #[derive(Clone, Serialize, Deserialize, Debug)] #[serde(untagged)] pub enum AssistantSettingsContent { - Versioned(VersionedAssistantSettingsContent), + Versioned(Box), Legacy(LegacyAssistantSettingsContent), } @@ -121,14 +135,14 @@ impl JsonSchema for AssistantSettingsContent { impl Default for AssistantSettingsContent { fn default() -> Self { - Self::Versioned(VersionedAssistantSettingsContent::default()) + Self::Versioned(Box::new(VersionedAssistantSettingsContent::default())) } } impl AssistantSettingsContent { pub fn is_version_outdated(&self) -> bool { match self { - AssistantSettingsContent::Versioned(settings) => match settings { + AssistantSettingsContent::Versioned(settings) => match **settings { VersionedAssistantSettingsContent::V1(_) => true, VersionedAssistantSettingsContent::V2(_) => false, }, @@ -138,8 +152,8 @@ impl AssistantSettingsContent { fn upgrade(&self) -> AssistantSettingsContentV2 { match self { - AssistantSettingsContent::Versioned(settings) => match settings { - VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 { + AssistantSettingsContent::Versioned(settings) => match **settings { + VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 { enabled: settings.enabled, button: settings.button, dock: settings.dock, @@ -186,7 +200,9 @@ impl AssistantSettingsContent { }) } }), - editor_model: None, + inline_assistant_model: None, + commit_message_model: None, + thread_summary_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, default_profile: None, @@ -194,7 +210,7 @@ impl AssistantSettingsContent { always_allow_tool_actions: None, notify_when_agent_waiting: None, }, - VersionedAssistantSettingsContent::V2(settings) => settings.clone(), + VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(), }, AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 { enabled: None, @@ -211,7 +227,9 @@ impl AssistantSettingsContent { .id() .to_string(), }), - editor_model: None, + inline_assistant_model: None, + commit_message_model: None, + thread_summary_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, default_profile: None, @@ -224,11 +242,11 @@ impl AssistantSettingsContent { pub fn set_dock(&mut self, dock: AssistantDockPosition) { match self { - AssistantSettingsContent::Versioned(settings) => match settings { - VersionedAssistantSettingsContent::V1(settings) => { + AssistantSettingsContent::Versioned(settings) => match **settings { + VersionedAssistantSettingsContent::V1(ref mut settings) => { settings.dock = Some(dock); } - VersionedAssistantSettingsContent::V2(settings) => { + VersionedAssistantSettingsContent::V2(ref mut settings) => { settings.dock = Some(dock); } }, @@ -243,77 +261,79 @@ impl AssistantSettingsContent { let provider = language_model.provider_id().0.to_string(); match self { - AssistantSettingsContent::Versioned(settings) => match settings { - VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() { - "zed.dev" => { - log::warn!("attempted to set zed.dev model on outdated settings"); - } - "anthropic" => { - let api_url = match &settings.provider { - Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => { - api_url.clone() - } - _ => None, - }; - settings.provider = Some(AssistantProviderContentV1::Anthropic { - default_model: AnthropicModel::from_id(&model).ok(), - api_url, - }); - } - "ollama" => { - let api_url = match &settings.provider { - Some(AssistantProviderContentV1::Ollama { api_url, .. }) => { - api_url.clone() - } - _ => None, - }; - settings.provider = Some(AssistantProviderContentV1::Ollama { - default_model: Some(ollama::Model::new(&model, None, None)), - api_url, - }); - } - "lmstudio" => { - let api_url = match &settings.provider { - Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => { - api_url.clone() - } - _ => None, - }; - settings.provider = Some(AssistantProviderContentV1::LmStudio { - default_model: Some(lmstudio::Model::new(&model, None, None)), - api_url, - }); - } - "openai" => { - let (api_url, available_models) = match &settings.provider { - Some(AssistantProviderContentV1::OpenAi { + AssistantSettingsContent::Versioned(settings) => match **settings { + VersionedAssistantSettingsContent::V1(ref mut settings) => { + match provider.as_ref() { + "zed.dev" => { + log::warn!("attempted to set zed.dev model on outdated settings"); + } + "anthropic" => { + let api_url = match &settings.provider { + Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => { + api_url.clone() + } + _ => None, + }; + settings.provider = Some(AssistantProviderContentV1::Anthropic { + default_model: AnthropicModel::from_id(&model).ok(), + api_url, + }); + } + "ollama" => { + let api_url = match &settings.provider { + Some(AssistantProviderContentV1::Ollama { api_url, .. }) => { + api_url.clone() + } + _ => None, + }; + settings.provider = Some(AssistantProviderContentV1::Ollama { + default_model: Some(ollama::Model::new(&model, None, None)), + api_url, + }); + } + "lmstudio" => { + let api_url = match &settings.provider { + Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => { + api_url.clone() + } + _ => None, + }; + settings.provider = Some(AssistantProviderContentV1::LmStudio { + default_model: Some(lmstudio::Model::new(&model, None, None)), + api_url, + }); + } + "openai" => { + let (api_url, available_models) = match &settings.provider { + Some(AssistantProviderContentV1::OpenAi { + api_url, + available_models, + .. + }) => (api_url.clone(), available_models.clone()), + _ => (None, None), + }; + settings.provider = Some(AssistantProviderContentV1::OpenAi { + default_model: OpenAiModel::from_id(&model).ok(), api_url, available_models, - .. - }) => (api_url.clone(), available_models.clone()), - _ => (None, None), - }; - settings.provider = Some(AssistantProviderContentV1::OpenAi { - default_model: OpenAiModel::from_id(&model).ok(), - api_url, - available_models, - }); + }); + } + "deepseek" => { + let api_url = match &settings.provider { + Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => { + api_url.clone() + } + _ => None, + }; + settings.provider = Some(AssistantProviderContentV1::DeepSeek { + default_model: DeepseekModel::from_id(&model).ok(), + api_url, + }); + } + _ => {} } - "deepseek" => { - let api_url = match &settings.provider { - Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => { - api_url.clone() - } - _ => None, - }; - settings.provider = Some(AssistantProviderContentV1::DeepSeek { - default_model: DeepseekModel::from_id(&model).ok(), - api_url, - }); - } - _ => {} - }, - VersionedAssistantSettingsContent::V2(settings) => { + } + VersionedAssistantSettingsContent::V2(ref mut settings) => { settings.default_model = Some(LanguageModelSelection { provider, model }); } }, @@ -325,23 +345,48 @@ impl AssistantSettingsContent { } } + pub fn set_inline_assistant_model(&mut self, provider: String, model: String) { + if let AssistantSettingsContent::Versioned(boxed) = self { + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + settings.inline_assistant_model = Some(LanguageModelSelection { provider, model }); + } + } + } + + pub fn set_commit_message_model(&mut self, provider: String, model: String) { + if let AssistantSettingsContent::Versioned(boxed) = self { + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + settings.commit_message_model = Some(LanguageModelSelection { provider, model }); + } + } + } + + pub fn set_thread_summary_model(&mut self, provider: String, model: String) { + if let AssistantSettingsContent::Versioned(boxed) = self { + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + settings.thread_summary_model = Some(LanguageModelSelection { provider, model }); + } + } + } + pub fn set_always_allow_tool_actions(&mut self, allow: bool) { - let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) = - self - else { + let AssistantSettingsContent::Versioned(boxed) = self else { return; }; - settings.always_allow_tool_actions = Some(allow); + + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + settings.always_allow_tool_actions = Some(allow); + } } pub fn set_profile(&mut self, profile_id: AgentProfileId) { - let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) = - self - else { + let AssistantSettingsContent::Versioned(boxed) = self else { return; }; - settings.default_profile = Some(profile_id); + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + settings.default_profile = Some(profile_id); + } } pub fn create_profile( @@ -349,37 +394,37 @@ impl AssistantSettingsContent { profile_id: AgentProfileId, profile: AgentProfile, ) -> Result<()> { - let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) = - self - else { + let AssistantSettingsContent::Versioned(boxed) = self else { return Ok(()); }; - let profiles = settings.profiles.get_or_insert_default(); - if profiles.contains_key(&profile_id) { - bail!("profile with ID '{profile_id}' already exists"); - } + if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed { + let profiles = settings.profiles.get_or_insert_default(); + if profiles.contains_key(&profile_id) { + bail!("profile with ID '{profile_id}' already exists"); + } - profiles.insert( - profile_id, - AgentProfileContent { - name: profile.name.into(), - tools: profile.tools, - enable_all_context_servers: Some(profile.enable_all_context_servers), - context_servers: profile - .context_servers - .into_iter() - .map(|(server_id, preset)| { - ( - server_id, - ContextServerPresetContent { - tools: preset.tools, - }, - ) - }) - .collect(), - }, - ); + profiles.insert( + profile_id, + AgentProfileContent { + name: profile.name.into(), + tools: profile.tools, + enable_all_context_servers: Some(profile.enable_all_context_servers), + context_servers: profile + .context_servers + .into_iter() + .map(|(server_id, preset)| { + ( + server_id, + ContextServerPresetContent { + tools: preset.tools, + }, + ) + }) + .collect(), + }, + ); + } Ok(()) } @@ -403,7 +448,9 @@ impl Default for VersionedAssistantSettingsContent { default_width: None, default_height: None, default_model: None, - editor_model: None, + inline_assistant_model: None, + commit_message_model: None, + thread_summary_model: None, inline_alternatives: None, enable_experimental_live_diffs: None, default_profile: None, @@ -436,10 +483,14 @@ pub struct AssistantSettingsContentV2 { /// /// Default: 320 default_height: Option, - /// The default model to use when creating new chats. + /// The default model to use when creating new chats and for other features when a specific model is not specified. default_model: Option, - /// The model to use when applying edits from the assistant. - editor_model: Option, + /// Model to use for the inline assistant. Defaults to default_model when not specified. + inline_assistant_model: Option, + /// Model to use for generating git commit messages. Defaults to default_model when not specified. + commit_message_model: Option, + /// Model to use for generating thread summaries. Defaults to default_model when not specified. + thread_summary_model: Option, /// Additional models with which to generate alternatives when performing inline assists. inline_alternatives: Option>, /// Enable experimental live diffs in the assistant panel. @@ -601,7 +652,15 @@ impl Settings for AssistantSettings { value.default_height.map(Into::into), ); merge(&mut settings.default_model, value.default_model); - merge(&mut settings.editor_model, value.editor_model); + settings.inline_assistant_model = value + .inline_assistant_model + .or(settings.inline_assistant_model.take()); + settings.commit_message_model = value + .commit_message_model + .or(settings.commit_message_model.take()); + settings.thread_summary_model = value + .thread_summary_model + .or(settings.thread_summary_model.take()); merge(&mut settings.inline_alternatives, value.inline_alternatives); merge( &mut settings.enable_experimental_live_diffs, @@ -692,16 +751,15 @@ mod tests { settings::SettingsStore::global(cx).update_settings_file::( fs.clone(), |settings, _| { - *settings = AssistantSettingsContent::Versioned( + *settings = AssistantSettingsContent::Versioned(Box::new( VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 { default_model: Some(LanguageModelSelection { provider: "test-provider".into(), model: "gpt-99".into(), }), - editor_model: Some(LanguageModelSelection { - provider: "test-provider".into(), - model: "gpt-99".into(), - }), + inline_assistant_model: None, + commit_message_model: None, + thread_summary_model: None, inline_alternatives: None, enabled: None, button: None, @@ -714,7 +772,7 @@ mod tests { always_allow_tool_actions: None, notify_when_agent_waiting: None, }), - ) + )) }, ); }); diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 6597f89943..d5933bfa02 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -9,7 +9,7 @@ use collections::HashSet; use edit_action::{EditAction, EditActionParser, edit_model_prompt}; use futures::{SinkExt, StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Entity, Task}; -use language_model::LanguageModelToolSchemaFormat; +use language_model::{ConfiguredModel, LanguageModelToolSchemaFormat}; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, }; @@ -205,8 +205,8 @@ impl EditToolRequest { cx: &mut App, ) -> Task> { let model_registry = LanguageModelRegistry::read_global(cx); - let Some(model) = model_registry.editor_model() else { - return Task::ready(Err(anyhow!("No editor model configured"))); + let Some(ConfiguredModel { model, .. }) = model_registry.default_model() else { + return Task::ready(Err(anyhow!("No model configured"))); }; let mut messages = messages.to_vec(); diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 2d662a5a2c..7853784b57 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -37,7 +37,8 @@ use gpui::{ use itertools::Itertools; use language::{Buffer, File}; use language_model::{ - LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, + ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, Role, }; use menu::{Confirm, SecondaryConfirm, SelectFirst, SelectLast, SelectNext, SelectPrevious}; use multi_buffer::ExcerptInfo; @@ -3764,8 +3765,9 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option, - editor_model: Option, + default_model: Option, + inline_assistant_model: Option, + commit_message_model: Option, + thread_summary_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, } -pub struct ActiveModel { - provider: Arc, - model: Option>, +#[derive(Clone)] +pub struct ConfiguredModel { + pub provider: Arc, + pub model: Arc, } pub enum Event { - ActiveModelChanged, - EditorModelChanged, + DefaultModelChanged, + InlineAssistantModelChanged, + CommitMessageModelChanged, + ThreadSummaryModelChanged, ProviderStateChanged, AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), @@ -54,7 +59,7 @@ impl LanguageModelRegistry { let mut registry = Self::default(); registry.register_provider(fake_provider.clone(), cx); let model = fake_provider.provided_models(cx)[0].clone(); - registry.set_active_model(Some(model), cx); + registry.set_default_model(Some(model), cx); registry }); cx.set_global(GlobalLanguageModelRegistry(registry)); @@ -114,7 +119,7 @@ impl LanguageModelRegistry { self.providers.get(id).cloned() } - pub fn select_active_model( + pub fn select_default_model( &mut self, provider: &LanguageModelProviderId, model_id: &LanguageModelId, @@ -126,11 +131,11 @@ impl LanguageModelRegistry { let models = provider.provided_models(cx); if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_active_model(Some(model), cx); + self.set_default_model(Some(model), cx); } } - pub fn select_editor_model( + pub fn select_inline_assistant_model( &mut self, provider: &LanguageModelProviderId, model_id: &LanguageModelId, @@ -142,23 +147,43 @@ impl LanguageModelRegistry { let models = provider.provided_models(cx); if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_editor_model(Some(model), cx); + self.set_inline_assistant_model(Some(model), cx); } } - pub fn set_active_provider( + pub fn select_commit_message_model( &mut self, - provider: Option>, + provider: &LanguageModelProviderId, + model_id: &LanguageModelId, cx: &mut Context, ) { - self.active_model = provider.map(|provider| ActiveModel { - provider, - model: None, - }); - cx.emit(Event::ActiveModelChanged); + let Some(provider) = self.provider(provider) else { + return; + }; + + let models = provider.provided_models(cx); + if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { + self.set_commit_message_model(Some(model), cx); + } } - pub fn set_active_model( + pub fn select_thread_summary_model( + &mut self, + provider: &LanguageModelProviderId, + model_id: &LanguageModelId, + cx: &mut Context, + ) { + let Some(provider) = self.provider(provider) else { + return; + }; + + let models = provider.provided_models(cx); + if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { + self.set_thread_summary_model(Some(model), cx); + } + } + + pub fn set_default_model( &mut self, model: Option>, cx: &mut Context, @@ -166,21 +191,18 @@ impl LanguageModelRegistry { if let Some(model) = model { let provider_id = model.provider_id(); if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.active_model = Some(ActiveModel { - provider, - model: Some(model), - }); - cx.emit(Event::ActiveModelChanged); + self.default_model = Some(ConfiguredModel { provider, model }); + cx.emit(Event::DefaultModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { - self.active_model = None; - cx.emit(Event::ActiveModelChanged); + self.default_model = None; + cx.emit(Event::DefaultModelChanged); } } - pub fn set_editor_model( + pub fn set_inline_assistant_model( &mut self, model: Option>, cx: &mut Context, @@ -188,35 +210,80 @@ impl LanguageModelRegistry { if let Some(model) = model { let provider_id = model.provider_id(); if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.editor_model = Some(ActiveModel { - provider, - model: Some(model), - }); - cx.emit(Event::EditorModelChanged); + self.inline_assistant_model = Some(ConfiguredModel { provider, model }); + cx.emit(Event::InlineAssistantModelChanged); } else { - log::warn!("Active model's provider not found in registry"); + log::warn!("Inline assistant model's provider not found in registry"); } } else { - self.editor_model = None; - cx.emit(Event::EditorModelChanged); + self.inline_assistant_model = None; + cx.emit(Event::InlineAssistantModelChanged); } } - pub fn active_provider(&self) -> Option> { + pub fn set_commit_message_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + if let Some(model) = model { + let provider_id = model.provider_id(); + if let Some(provider) = self.providers.get(&provider_id).cloned() { + self.commit_message_model = Some(ConfiguredModel { provider, model }); + cx.emit(Event::CommitMessageModelChanged); + } else { + log::warn!("Commit message model's provider not found in registry"); + } + } else { + self.commit_message_model = None; + cx.emit(Event::CommitMessageModelChanged); + } + } + + pub fn set_thread_summary_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + if let Some(model) = model { + let provider_id = model.provider_id(); + if let Some(provider) = self.providers.get(&provider_id).cloned() { + self.thread_summary_model = Some(ConfiguredModel { provider, model }); + cx.emit(Event::ThreadSummaryModelChanged); + } else { + log::warn!("Thread summary model's provider not found in registry"); + } + } else { + self.thread_summary_model = None; + cx.emit(Event::ThreadSummaryModelChanged); + } + } + + pub fn default_model(&self) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; } - Some(self.active_model.as_ref()?.provider.clone()) + self.default_model.clone() } - pub fn active_model(&self) -> Option> { - self.active_model.as_ref()?.model.clone() + pub fn inline_assistant_model(&self) -> Option { + self.inline_assistant_model + .clone() + .or_else(|| self.default_model()) } - pub fn editor_model(&self) -> Option> { - self.editor_model.as_ref()?.model.clone() + pub fn commit_message_model(&self) -> Option { + self.commit_message_model + .clone() + .or_else(|| self.default_model()) + } + + pub fn thread_summary_model(&self) -> Option { + self.thread_summary_model + .clone() + .or_else(|| self.default_model()) } /// Selects and sets the inline alternatives for language models based on diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index b6ef5a77eb..8cb8a1afb6 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -168,11 +168,11 @@ impl LanguageModelSelector { } fn get_active_model_index(cx: &App) -> usize { - let active_model = LanguageModelRegistry::read_global(cx).active_model(); + let active_model = LanguageModelRegistry::read_global(cx).default_model(); Self::all_models(cx) .iter() .position(|model_info| { - Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id()) + Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id()) }) .unwrap_or(0) } @@ -406,13 +406,10 @@ impl PickerDelegate for LanguageModelPickerDelegate { let model_info = self.filtered_models.get(ix)?; let provider_name: String = model_info.model.provider_name().0.clone().into(); - let active_provider_id = LanguageModelRegistry::read_global(cx) - .active_provider() - .map(|m| m.id()); + let active_model = LanguageModelRegistry::read_global(cx).default_model(); - let active_model_id = LanguageModelRegistry::read_global(cx) - .active_model() - .map(|m| m.id()); + let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); + let active_model_id = active_model.map(|m| m.model.id()); let is_selected = Some(model_info.model.provider_id()) == active_provider_id && Some(model_info.model.id()) == active_model_id; diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs index 48d6c5e8f8..c2c1f3da60 100644 --- a/crates/prompt_library/src/prompt_library.rs +++ b/crates/prompt_library/src/prompt_library.rs @@ -9,7 +9,7 @@ use gpui::{ }; use language::{Buffer, LanguageRegistry, language_settings::SoftWrap}; use language_model::{ - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, + ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use picker::{Picker, PickerDelegate}; use release_channel::ReleaseChannel; @@ -777,7 +777,9 @@ impl PromptLibrary { }; let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor; - let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else { + let Some(ConfiguredModel { provider, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + else { return; }; @@ -880,7 +882,9 @@ impl PromptLibrary { } fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context) { - let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else { + let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).default_model() + else { return; }; if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) { @@ -967,7 +971,9 @@ impl PromptLibrary { let prompt_metadata = self.store.metadata(prompt_id)?; let prompt_editor = &self.prompt_editors[&prompt_id]; let focus_handle = prompt_editor.body_editor.focus_handle(cx); - let model = LanguageModelRegistry::read_global(cx).active_model(); + let model = LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default| default.model); let settings = ThemeSettings::get_global(cx); Some( diff --git a/docs/src/assistant/configuration.md b/docs/src/assistant/configuration.md index 90d215281c..bce5b78881 100644 --- a/docs/src/assistant/configuration.md +++ b/docs/src/assistant/configuration.md @@ -19,7 +19,8 @@ To further customize providers, you can use `settings.json` to do that as follow - [Configuring endpoints](#custom-endpoint) - [Configuring timeouts](#provider-timeout) -- [Configuring default model](#default-model) +- [Configuring models](#default-model) +- [Configuring feature-specific models](#feature-specific-models) - [Configuring alternative models for inline assists](#alternative-assists) ### Zed AI {#zed-ai} @@ -281,8 +282,24 @@ Example configuration for using X.ai Grok with Zed: "enabled": true, "default_model": { "provider": "zed.dev", + "model": "claude-3-7-sonnet" + }, + "editor_model": { + "provider": "openai", + "model": "gpt-4o" + }, + "inline_assistant_model": { + "provider": "anthropic", "model": "claude-3-5-sonnet" }, + "commit_message_model": { + "provider": "openai", + "model": "gpt-4o-mini" + }, + "thread_summary_model": { + "provider": "google", + "model": "gemini-1.5-flash" + }, "version": "2", "button": true, "default_width": 480, @@ -328,7 +345,7 @@ To do so, add the following to your Zed `settings.json`: Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`. -#### Configuring the default model {#default-model} +#### Configuring models {#default-model} The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default. You can also manually edit the `default_model` object in your settings: @@ -345,6 +362,47 @@ You can also manually edit the `default_model` object in your settings: } ``` +#### Feature-specific models {#feature-specific-models} + +> Currently only available in [Preview](https://zed.dev/releases/preview). + +Zed allows you to configure different models for specific features. +This provides flexibility to use more powerful models for certain tasks while using faster or more efficient models for others. + +If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel. + +You can configure the following feature-specific models: + +- Thread summary model: Used for generating thread summaries +- Inline assistant model: Used for the inline assistant feature +- Commit message model: Used for generating Git commit messages + +Example configuration: + +```json +{ + "assistant": { + "version": "2", + "default_model": { + "provider": "zed.dev", + "model": "claude-3-7-sonnet" + }, + "inline_assistant_model": { + "provider": "anthropic", + "model": "claude-3-5-sonnet" + }, + "commit_message_model": { + "provider": "openai", + "model": "gpt-4o-mini" + }, + "thread_summary_model": { + "provider": "google", + "model": "gemini-2.0-flash" + } + } +} +``` + #### Configuring alternative models for inline assists {#alternative-assists} You can configure additional models that will be used to perform inline assists in parallel. When you do this,