ai: Separate model settings for each feature (#28088)

Closes: https://github.com/zed-industries/zed/issues/20582

Allows users to select a specific model for each AI-powered feature:
- Agent panel
- Inline assistant
- Thread summarization
- Commit message generation

If unspecified for a given feature, it will use the `default_model`
setting.

Release Notes:

- Added support for configuring a specific model for each AI-powered
feature

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Agus Zubiaga 2025-04-04 11:40:55 -03:00 committed by GitHub
parent cf0d1e4229
commit 43cb925a59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 670 additions and 381 deletions

View file

@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::ProjectItem as _;
use settings::{Settings as _, update_settings_file};
@ -606,7 +606,7 @@ impl ActiveThread {
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
@ -814,21 +814,17 @@ impl ActiveThread {
}
});
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if model.provider.must_accept_terms(cx) {
cx.notify();
return;
}
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model, RequestKind::Chat, cx)
thread.send_to_model(model.model, RequestKind::Chat, cx)
});
cx.notify();
}

View file

@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate {
let default_profile = self.profile.clone();
let tool = tool.clone();
move |settings, _cx| match settings {
AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
settings,
)) => {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
AssistantSettingsContent::Versioned(boxed) => {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
}
}
}
}

View file

@ -9,10 +9,17 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
}
impl AssistantModelSelector {
@ -20,6 +27,7 @@ impl AssistantModelSelector {
fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
window: &mut Window,
cx: &mut App,
) -> Self {
@ -28,11 +36,32 @@ impl AssistantModelSelector {
let fs = fs.clone();
LanguageModelSelector::new(
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| settings.set_model(model.clone()),
);
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match model_type {
ModelType::Default => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model.clone());
},
);
}
ModelType::InlineAssistant => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_inline_assistant_model(
provider.clone(),
model_id.clone(),
);
},
);
}
}
},
window,
cx,
@ -40,6 +69,7 @@ impl AssistantModelSelector {
}),
menu_handle,
focus_handle,
model_type,
}
}
@ -50,10 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let focus_handle = self.focus_handle.clone();
let model_name = match active_model {
Some(model) => model.name().0,
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
};

View file

@ -571,10 +571,8 @@ impl AssistantPanel {
match event {
AssistantConfigurationEvent::NewThread(provider) => {
if LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(true, |active_provider| {
active_provider.id() != provider.id()
})
.default_model()
.map_or(true, |model| model.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@ -922,16 +920,18 @@ impl AssistantPanel {
}
fn configuration_error(&self, cx: &App) -> Option<ConfigurationError> {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return Some(ConfigurationError::NoProvider);
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return Some(ConfigurationError::ProviderNotAuthenticated);
}
if provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
if model.provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(
model.provider,
));
}
None

View file

@ -156,8 +156,9 @@ impl BufferCodegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)

View file

@ -239,8 +239,8 @@ impl InlineAssistant {
let is_authenticated = || {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
.inline_assistant_model()
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
@ -279,8 +279,8 @@ impl InlineAssistant {
cx.spawn_in(window, async move |_workspace, cx| {
let Some(task) = cx.update(|_, cx| {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
.inline_assistant_model()
.map_or(None, |model| Some(model.provider.authenticate(cx)))
})?
else {
let answer = cx
@ -401,14 +401,14 @@ impl InlineAssistant {
codegen_ranges.push(anchor_range);
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
phase: AssistantPhase::Invoked,
message_id: None,
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.provider.id().to_string(),
response_latency: None,
error_message: None,
language_name: buffer.language().map(|language| language.name().to_proto()),
@ -976,7 +976,7 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let snapshot = multibuffer.snapshot(cx);
@ -996,15 +996,15 @@ impl InlineAssistant {
} else {
AssistantPhase::Accepted
},
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.model.provider_id().to_string(),
response_latency: None,
error_message: None,
language_name: language_name.map(|name| name.to_proto()),
},
Some(self.telemetry.clone()),
cx.http_client(),
model.api_key(cx),
model.model.api_key(cx),
cx.background_executor(),
);
}

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@ -582,7 +582,7 @@ impl<T: 'static> PromptEditor<T> {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
let default_model = model_registry.active_model();
let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@ -890,6 +890,7 @@ impl PromptEditor<BufferCodegen> {
fs,
model_selector_menu_handle,
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)
@ -1042,6 +1043,7 @@ impl PromptEditor<TerminalCodegen> {
fs,
model_selector_menu_handle.clone(),
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
@ -10,7 +11,7 @@ use gpui::{
WeakEntity, linear_color_stop, linear_gradient, point,
};
use language::Buffer;
use language_model::LanguageModelRegistry;
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@ -139,6 +140,7 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default,
window,
cx,
)
@ -191,7 +193,7 @@ impl MessageEditor {
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.active_model()
.default_model()
.is_some()
}
@ -201,20 +203,16 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
return;
};
if provider.must_accept_terms(cx) {
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
return;
};
let user_message = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx);
editor.clear(window, cx);

View file

@ -130,8 +130,8 @@ impl Render for ProfileSelector {
let model_registry = LanguageModelRegistry::read_global(cx);
let supports_tools = model_registry
.active_model()
.map_or(false, |model| model.supports_tools());
.default_model()
.map_or(false, |default| default.model.supports_tools());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,

View file

@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus;
use client::telemetry::Telemetry;
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
};
use std::{sync::Arc, time::Instant};
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
use terminal::Terminal;
@ -31,7 +33,9 @@ impl TerminalCodegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};

View file

@ -13,8 +13,8 @@ use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
report_assistant_event,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use prompt_store::PromptBuilder;
use std::sync::Arc;
@ -286,7 +286,9 @@ impl TerminalInlineAssistant {
})
.log_err();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(

View file

@ -14,10 +14,10 @@ use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
@ -1250,14 +1250,11 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
return;
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return;
}
@ -1276,7 +1273,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
@ -1320,8 +1317,8 @@ impl Thread {
_ => {}
}
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let ConfiguredModel { model, provider } =
LanguageModelRegistry::read_global(cx).thread_summary_model()?;
if !provider.is_authenticated(cx) {
return None;
@ -1782,11 +1779,11 @@ impl Thread {
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
let Some(model) = model_registry.default_model() else {
return TotalTokenUsage::default();
};
let max = model.max_token_count();
let max = model.model.max_token_count();
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")