ai: Separate model settings for each feature (#28088)

Closes: https://github.com/zed-industries/zed/issues/20582

Allows users to select a specific model for each AI-powered feature:
- Agent panel
- Inline assistant
- Thread summarization
- Commit message generation

If unspecified for a given feature, it will use the `default_model`
setting.

Release Notes:

- Added support for configuring a specific model for each AI-powered
feature

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Agus Zubiaga 2025-04-04 11:40:55 -03:00 committed by GitHub
parent cf0d1e4229
commit 43cb925a59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 670 additions and 381 deletions

View file

@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::ProjectItem as _;
use settings::{Settings as _, update_settings_file};
@ -606,7 +606,7 @@ impl ActiveThread {
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
@ -814,21 +814,17 @@ impl ActiveThread {
}
});
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if model.provider.must_accept_terms(cx) {
cx.notify();
return;
}
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model, RequestKind::Chat, cx)
thread.send_to_model(model.model, RequestKind::Chat, cx)
});
cx.notify();
}

View file

@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate {
let default_profile = self.profile.clone();
let tool = tool.clone();
move |settings, _cx| match settings {
AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
settings,
)) => {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
AssistantSettingsContent::Versioned(boxed) => {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
}
}
}
}

View file

@ -9,10 +9,17 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
}
impl AssistantModelSelector {
@ -20,6 +27,7 @@ impl AssistantModelSelector {
fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
window: &mut Window,
cx: &mut App,
) -> Self {
@ -28,11 +36,32 @@ impl AssistantModelSelector {
let fs = fs.clone();
LanguageModelSelector::new(
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| settings.set_model(model.clone()),
);
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match model_type {
ModelType::Default => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model.clone());
},
);
}
ModelType::InlineAssistant => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_inline_assistant_model(
provider.clone(),
model_id.clone(),
);
},
);
}
}
},
window,
cx,
@ -40,6 +69,7 @@ impl AssistantModelSelector {
}),
menu_handle,
focus_handle,
model_type,
}
}
@ -50,10 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let focus_handle = self.focus_handle.clone();
let model_name = match active_model {
Some(model) => model.name().0,
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
};

View file

@ -571,10 +571,8 @@ impl AssistantPanel {
match event {
AssistantConfigurationEvent::NewThread(provider) => {
if LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(true, |active_provider| {
active_provider.id() != provider.id()
})
.default_model()
.map_or(true, |model| model.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@ -922,16 +920,18 @@ impl AssistantPanel {
}
fn configuration_error(&self, cx: &App) -> Option<ConfigurationError> {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return Some(ConfigurationError::NoProvider);
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return Some(ConfigurationError::ProviderNotAuthenticated);
}
if provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
if model.provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(
model.provider,
));
}
None

View file

@ -156,8 +156,9 @@ impl BufferCodegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)

View file

@ -239,8 +239,8 @@ impl InlineAssistant {
let is_authenticated = || {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
.inline_assistant_model()
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
@ -279,8 +279,8 @@ impl InlineAssistant {
cx.spawn_in(window, async move |_workspace, cx| {
let Some(task) = cx.update(|_, cx| {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
.inline_assistant_model()
.map_or(None, |model| Some(model.provider.authenticate(cx)))
})?
else {
let answer = cx
@ -401,14 +401,14 @@ impl InlineAssistant {
codegen_ranges.push(anchor_range);
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
phase: AssistantPhase::Invoked,
message_id: None,
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.provider.id().to_string(),
response_latency: None,
error_message: None,
language_name: buffer.language().map(|language| language.name().to_proto()),
@ -976,7 +976,7 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let snapshot = multibuffer.snapshot(cx);
@ -996,15 +996,15 @@ impl InlineAssistant {
} else {
AssistantPhase::Accepted
},
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.model.provider_id().to_string(),
response_latency: None,
error_message: None,
language_name: language_name.map(|name| name.to_proto()),
},
Some(self.telemetry.clone()),
cx.http_client(),
model.api_key(cx),
model.model.api_key(cx),
cx.background_executor(),
);
}

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@ -582,7 +582,7 @@ impl<T: 'static> PromptEditor<T> {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
let default_model = model_registry.active_model();
let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@ -890,6 +890,7 @@ impl PromptEditor<BufferCodegen> {
fs,
model_selector_menu_handle,
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)
@ -1042,6 +1043,7 @@ impl PromptEditor<TerminalCodegen> {
fs,
model_selector_menu_handle.clone(),
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
@ -10,7 +11,7 @@ use gpui::{
WeakEntity, linear_color_stop, linear_gradient, point,
};
use language::Buffer;
use language_model::LanguageModelRegistry;
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@ -139,6 +140,7 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default,
window,
cx,
)
@ -191,7 +193,7 @@ impl MessageEditor {
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.active_model()
.default_model()
.is_some()
}
@ -201,20 +203,16 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
return;
};
if provider.must_accept_terms(cx) {
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
return;
};
let user_message = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx);
editor.clear(window, cx);

View file

@ -130,8 +130,8 @@ impl Render for ProfileSelector {
let model_registry = LanguageModelRegistry::read_global(cx);
let supports_tools = model_registry
.active_model()
.map_or(false, |model| model.supports_tools());
.default_model()
.map_or(false, |default| default.model.supports_tools());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,

View file

@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus;
use client::telemetry::Telemetry;
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
};
use std::{sync::Arc, time::Instant};
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
use terminal::Terminal;
@ -31,7 +33,9 @@ impl TerminalCodegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};

View file

@ -13,8 +13,8 @@ use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
report_assistant_event,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use prompt_store::PromptBuilder;
use std::sync::Arc;
@ -286,7 +286,9 @@ impl TerminalInlineAssistant {
})
.log_err();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(

View file

@ -14,10 +14,10 @@ use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
@ -1250,14 +1250,11 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
return;
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return;
}
@ -1276,7 +1273,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
@ -1320,8 +1317,8 @@ impl Thread {
_ => {}
}
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let ConfiguredModel { model, provider } =
LanguageModelRegistry::read_global(cx).thread_summary_model()?;
if !provider.is_authenticated(cx) {
return None;
@ -1782,11 +1779,11 @@ impl Thread {
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
let Some(model) = model_registry.default_model() else {
return TotalTokenUsage::default();
};
let max = model.max_token_count();
let max = model.model.max_token_count();
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")

View file

@ -161,12 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
fn update_active_language_model_from_settings(cx: &mut App) {
let settings = AssistantSettings::get_global(cx);
// Default model - used as fallback
let active_model_provider_name =
LanguageModelProviderId::from(settings.default_model.provider.clone());
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
let editor_provider_name =
LanguageModelProviderId::from(settings.editor_model.provider.clone());
let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
// Inline assistant model
let inline_assistant_model = settings
.inline_assistant_model
.as_ref()
.unwrap_or(&settings.default_model);
let inline_assistant_provider_name =
LanguageModelProviderId::from(inline_assistant_model.provider.clone());
let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
// Commit message model
let commit_message_model = settings
.commit_message_model
.as_ref()
.unwrap_or(&settings.default_model);
let commit_message_provider_name =
LanguageModelProviderId::from(commit_message_model.provider.clone());
let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
// Thread summary model
let thread_summary_model = settings
.thread_summary_model
.as_ref()
.unwrap_or(&settings.default_model);
let thread_summary_provider_name =
LanguageModelProviderId::from(thread_summary_model.provider.clone());
let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
let inline_alternatives = settings
.inline_alternatives
.iter()
@ -177,9 +203,29 @@ fn update_active_language_model_from_settings(cx: &mut App) {
)
})
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
// Set the default model
registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
// Set the specific models
registry.select_inline_assistant_model(
&inline_assistant_provider_name,
&inline_assistant_model_id,
cx,
);
registry.select_commit_message_model(
&commit_message_provider_name,
&commit_message_model_id,
cx,
);
registry.select_thread_summary_model(
&thread_summary_provider_name,
&thread_summary_model_id,
cx,
);
// Set the alternatives
registry.select_inline_alternative_models(inline_alternatives, cx);
});
}

View file

@ -22,7 +22,8 @@ use gpui::{
};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
@ -298,8 +299,10 @@ impl AssistantPanel {
&LanguageModelRegistry::global(cx),
window,
|this, _, event: &language_model::Event, window, cx| match event {
language_model::Event::ActiveModelChanged
| language_model::Event::EditorModelChanged => {
language_model::Event::DefaultModelChanged
| language_model::Event::InlineAssistantModelChanged
| language_model::Event::CommitMessageModelChanged
| language_model::Event::ThreadSummaryModelChanged => {
this.completion_provider_changed(window, cx);
}
language_model::Event::ProviderStateChanged => {
@ -468,12 +471,12 @@ impl AssistantPanel {
}
fn update_zed_ai_notice_visibility(&mut self, client_status: Status, cx: &mut Context<Self>) {
let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
let model = LanguageModelRegistry::read_global(cx).default_model();
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in.
let show_zed_ai_notice = client_status.is_signed_out()
&& active_provider.map_or(true, |provider| provider.id().0 == ZED_CLOUD_PROVIDER_ID);
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify();
@ -541,8 +544,8 @@ impl AssistantPanel {
}
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
.default_model()
.map(|default| default.provider.id())
else {
return;
};
@ -568,7 +571,9 @@ impl AssistantPanel {
return;
}
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(ConfiguredModel { provider, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
else {
return;
};
@ -976,8 +981,8 @@ impl AssistantPanel {
|this, _, event: &ConfigurationViewEvent, window, cx| match event {
ConfigurationViewEvent::NewProviderContextEditor(provider) => {
if LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(true, |p| p.id() != provider.id())
.default_model()
.map_or(true, |default| default.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@ -1155,8 +1160,8 @@ impl AssistantPanel {
fn is_authenticated(&mut self, cx: &mut Context<Self>) -> bool {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
.default_model()
.map_or(false, |default| default.provider.is_authenticated(cx))
}
fn authenticate(
@ -1164,8 +1169,8 @@ impl AssistantPanel {
cx: &mut Context<Self>,
) -> Option<Task<Result<(), AuthenticateError>>> {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
.default_model()
.map_or(None, |default| Some(default.provider.authenticate(cx)))
}
fn restart_context_servers(

View file

@ -34,8 +34,8 @@ use gpui::{
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelTextStream, Role, report_assistant_event,
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
@ -312,7 +312,9 @@ impl InlineAssistant {
start..end,
));
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
{
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
@ -877,7 +879,9 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
{
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let multibuffer_snapshot = multibuffer.snapshot(cx);
@ -1629,8 +1633,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.default_model()
.map(|default| default.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@ -2077,7 +2081,7 @@ impl PromptEditor {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
let default_model = model_registry.active_model();
let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@ -2183,7 +2187,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.default_model()?
.model;
let token_counts = self.token_counts?;
let max_token_count = model.max_token_count();
@ -2638,8 +2644,9 @@ impl Codegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
@ -2863,7 +2870,9 @@ impl CodegenAlternative {
assistant_panel_context: Option<LanguageModelRequest>,
cx: &App,
) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
match request {
Ok(request) => {

View file

@ -16,8 +16,8 @@ use gpui::{
};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
report_assistant_event,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
@ -318,7 +318,9 @@ impl TerminalInlineAssistant {
})
.log_err();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(
@ -652,8 +654,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.inline_assistant_model()
.map(|inline_assistant| inline_assistant.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@ -822,7 +824,9 @@ impl PromptEditor {
fn count_tokens(&mut self, cx: &mut Context<Self>) {
let assist_id = self.id;
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
self.pending_token_count = cx.spawn(async move |this, cx| {
@ -980,7 +984,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.inline_assistant_model()?
.model;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@ -1131,7 +1137,9 @@ impl Codegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};

View file

@ -1272,7 +1272,7 @@ impl AssistantContext {
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
let request = self.to_completion_request(RequestType::Chat, cx);
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
let debounce = self.token_count.is_some();
@ -1284,10 +1284,12 @@ impl AssistantContext {
.await;
}
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
let token_count = cx
.update(|cx| model.model.count_tokens(request, cx))?
.await?;
this.update(cx, |this, cx| {
this.token_count = Some(token_count);
this.start_cache_warming(&model, cx);
this.start_cache_warming(&model.model, cx);
cx.notify()
})
}
@ -2304,14 +2306,16 @@ impl AssistantContext {
cx: &mut Context<Self>,
) -> Option<MessageAnchor> {
let model_registry = LanguageModelRegistry::read_global(cx);
let provider = model_registry.active_provider()?;
let model = model_registry.active_model()?;
let model = model_registry.default_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?;
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
let model = model.model;
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
@ -2940,15 +2944,12 @@ impl AssistantContext {
}
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return;
}
@ -2964,7 +2965,7 @@ impl AssistantContext {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;

View file

@ -384,7 +384,9 @@ impl ContextEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
@ -2395,13 +2397,13 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let model = LanguageModelRegistry::read_global(cx).default_model();
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
&& provider
&& model
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));
.map_or(false, |model| model.provider.must_accept_terms(cx));
let disabled = has_configuration_error || needs_to_accept_terms;
ButtonLike::new("send_button")
@ -2454,7 +2456,9 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
@ -2500,7 +2504,9 @@ impl ContextEditor {
}
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
let focus_handle = self.editor().focus_handle(cx).clone();
let model_name = match active_model {
Some(model) => model.name().0,
@ -3020,7 +3026,9 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
impl Render for ContextEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let accept_terms = if self.show_accept_terms {
provider.as_ref().and_then(|provider| {
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
@ -3616,7 +3624,9 @@ enum TokenState {
fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
const WARNING_TOKEN_THRESHOLD: f32 = 0.8;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.default_model()?
.model;
let token_count = context.read(cx).token_count()?;
let max_token_count = model.max_token_count();
@ -3669,16 +3679,16 @@ pub enum ConfigurationError {
}
fn configuration_error(cx: &App) -> Option<ConfigurationError> {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let is_authenticated = provider
let model = LanguageModelRegistry::read_global(cx).default_model();
let is_authenticated = model
.as_ref()
.map_or(false, |provider| provider.is_authenticated(cx));
.map_or(false, |model| model.provider.is_authenticated(cx));
if provider.is_some() && is_authenticated {
if model.is_some() && is_authenticated {
return None;
}
if provider.is_none() {
if model.is_none() {
return Some(ConfigurationError::NoProvider);
}

View file

@ -156,10 +156,10 @@ impl HeadlessAssistant {
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
if let Some(model) = model_registry.default_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
thread.send_to_model(model, RequestKind::Chat, cx);
thread.send_to_model(model.model, RequestKind::Chat, cx);
});
} else {
println!(

View file

@ -37,9 +37,6 @@ struct Args {
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
@ -79,11 +76,6 @@ fn main() {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = if let Some(model_name) = &args.editor_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
let judge_model = if let Some(model_name) = &args.judge_model_name {
find_model(model_name, cx).unwrap()
} else {
@ -91,12 +83,10 @@ fn main() {
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
registry.set_editor_model(Some(editor_model.clone()), cx);
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
let framework_path_clone = framework_path.clone();
@ -110,10 +100,6 @@ fn main() {
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await

View file

@ -77,7 +77,9 @@ pub struct AssistantSettings {
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: LanguageModelSelection,
pub editor_model: LanguageModelSelection,
pub inline_assistant_model: Option<LanguageModelSelection>,
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
@ -95,13 +97,25 @@ impl AssistantSettings {
cx.is_staff() || self.enable_experimental_live_diffs
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
self.commit_message_model = Some(LanguageModelSelection { provider, model });
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
self.thread_summary_model = Some(LanguageModelSelection { provider, model });
}
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
Versioned(VersionedAssistantSettingsContent),
Versioned(Box<VersionedAssistantSettingsContent>),
Legacy(LegacyAssistantSettingsContent),
}
@ -121,14 +135,14 @@ impl JsonSchema for AssistantSettingsContent {
impl Default for AssistantSettingsContent {
fn default() -> Self {
Self::Versioned(VersionedAssistantSettingsContent::default())
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
}
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
},
@ -138,8 +152,8 @@ impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
@ -186,7 +200,9 @@ impl AssistantSettingsContent {
})
}
}),
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@ -194,7 +210,7 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
enabled: None,
@ -211,7 +227,9 @@ impl AssistantSettingsContent {
.id()
.to_string(),
}),
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@ -224,11 +242,11 @@ impl AssistantSettingsContent {
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
settings.dock = Some(dock);
}
VersionedAssistantSettingsContent::V2(settings) => {
VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.dock = Some(dock);
}
},
@ -243,77 +261,79 @@ impl AssistantSettingsContent {
let provider = language_model.provider_id().0.to_string();
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
});
}
"ollama" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
});
}
"lmstudio" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None)),
api_url,
});
}
"openai" => {
let (api_url, available_models) = match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
});
}
"ollama" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
});
}
"lmstudio" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None)),
api_url,
});
}
"openai" => {
let (api_url, available_models) = match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
});
});
}
"deepseek" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
default_model: DeepseekModel::from_id(&model).ok(),
api_url,
});
}
_ => {}
}
"deepseek" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
default_model: DeepseekModel::from_id(&model).ok(),
api_url,
});
}
_ => {}
},
VersionedAssistantSettingsContent::V2(settings) => {
}
VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
@ -325,23 +345,48 @@ impl AssistantSettingsContent {
}
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
settings.always_allow_tool_actions = Some(allow);
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.always_allow_tool_actions = Some(allow);
}
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
settings.default_profile = Some(profile_id);
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.default_profile = Some(profile_id);
}
}
pub fn create_profile(
@ -349,37 +394,37 @@ impl AssistantSettingsContent {
profile_id: AgentProfileId,
profile: AgentProfile,
) -> Result<()> {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return Ok(());
};
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
profiles.insert(
profile_id,
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
profiles.insert(
profile_id,
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
}
Ok(())
}
@ -403,7 +448,9 @@ impl Default for VersionedAssistantSettingsContent {
default_width: None,
default_height: None,
default_model: None,
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@ -436,10 +483,14 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new chats.
/// The default model to use when creating new chats and for other features when a specific model is not specified.
default_model: Option<LanguageModelSelection>,
/// The model to use when applying edits from the assistant.
editor_model: Option<LanguageModelSelection>,
/// Model to use for the inline assistant. Defaults to default_model when not specified.
inline_assistant_model: Option<LanguageModelSelection>,
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
commit_message_model: Option<LanguageModelSelection>,
/// Model to use for generating thread summaries. Defaults to default_model when not specified.
thread_summary_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
@ -601,7 +652,15 @@ impl Settings for AssistantSettings {
value.default_height.map(Into::into),
);
merge(&mut settings.default_model, value.default_model);
merge(&mut settings.editor_model, value.editor_model);
settings.inline_assistant_model = value
.inline_assistant_model
.or(settings.inline_assistant_model.take());
settings.commit_message_model = value
.commit_message_model
.or(settings.commit_message_model.take());
settings.thread_summary_model = value
.thread_summary_model
.or(settings.thread_summary_model.take());
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
@ -692,16 +751,15 @@ mod tests {
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
fs.clone(),
|settings, _| {
*settings = AssistantSettingsContent::Versioned(
*settings = AssistantSettingsContent::Versioned(Box::new(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
editor_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
@ -714,7 +772,7 @@ mod tests {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
}),
)
))
},
);
});

View file

@ -9,7 +9,7 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser, edit_model_prompt};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::LanguageModelToolSchemaFormat;
use language_model::{ConfiguredModel, LanguageModelToolSchemaFormat};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
@ -205,8 +205,8 @@ impl EditToolRequest {
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.editor_model() else {
return Task::ready(Err(anyhow!("No editor model configured")));
let Some(ConfiguredModel { model, .. }) = model_registry.default_model() else {
return Task::ready(Err(anyhow!("No model configured")));
};
let mut messages = messages.to_vec();

View file

@ -37,7 +37,8 @@ use gpui::{
use itertools::Itertools;
use language::{Buffer, File};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, Role,
};
use menu::{Confirm, SecondaryConfirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
use multi_buffer::ExcerptInfo;
@ -3764,8 +3765,9 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
assistant_settings::AssistantSettings::get_global(cx)
.enabled
.then(|| {
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let ConfiguredModel { provider, model } =
LanguageModelRegistry::read_global(cx).commit_message_model()?;
provider.is_authenticated(cx).then(|| model)
})
.flatten()

View file

@ -17,20 +17,25 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
editor_model: Option<ActiveModel>,
default_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
pub struct ActiveModel {
provider: Arc<dyn LanguageModelProvider>,
model: Option<Arc<dyn LanguageModel>>,
#[derive(Clone)]
pub struct ConfiguredModel {
pub provider: Arc<dyn LanguageModelProvider>,
pub model: Arc<dyn LanguageModel>,
}
pub enum Event {
ActiveModelChanged,
EditorModelChanged,
DefaultModelChanged,
InlineAssistantModelChanged,
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@ -54,7 +59,7 @@ impl LanguageModelRegistry {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
registry.set_active_model(Some(model), cx);
registry.set_default_model(Some(model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
@ -114,7 +119,7 @@ impl LanguageModelRegistry {
self.providers.get(id).cloned()
}
pub fn select_active_model(
pub fn select_default_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@ -126,11 +131,11 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_active_model(Some(model), cx);
self.set_default_model(Some(model), cx);
}
}
pub fn select_editor_model(
pub fn select_inline_assistant_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@ -142,23 +147,43 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_editor_model(Some(model), cx);
self.set_inline_assistant_model(Some(model), cx);
}
}
pub fn set_active_provider(
pub fn select_commit_message_model(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
self.active_model = provider.map(|provider| ActiveModel {
provider,
model: None,
});
cx.emit(Event::ActiveModelChanged);
let Some(provider) = self.provider(provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_commit_message_model(Some(model), cx);
}
}
pub fn set_active_model(
pub fn select_thread_summary_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
let Some(provider) = self.provider(provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_thread_summary_model(Some(model), cx);
}
}
pub fn set_default_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@ -166,21 +191,18 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.active_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(Event::ActiveModelChanged);
self.default_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::DefaultModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(Event::ActiveModelChanged);
self.default_model = None;
cx.emit(Event::DefaultModelChanged);
}
}
pub fn set_editor_model(
pub fn set_inline_assistant_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@ -188,35 +210,80 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.editor_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(Event::EditorModelChanged);
self.inline_assistant_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::InlineAssistantModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
log::warn!("Inline assistant model's provider not found in registry");
}
} else {
self.editor_model = None;
cx.emit(Event::EditorModelChanged);
self.inline_assistant_model = None;
cx.emit(Event::InlineAssistantModelChanged);
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
pub fn set_commit_message_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.commit_message_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::CommitMessageModelChanged);
} else {
log::warn!("Commit message model's provider not found in registry");
}
} else {
self.commit_message_model = None;
cx.emit(Event::CommitMessageModelChanged);
}
}
pub fn set_thread_summary_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.thread_summary_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::ThreadSummaryModelChanged);
} else {
log::warn!("Thread summary model's provider not found in registry");
}
} else {
self.thread_summary_model = None;
cx.emit(Event::ThreadSummaryModelChanged);
}
}
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
Some(self.active_model.as_ref()?.provider.clone())
self.default_model.clone()
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.as_ref()?.model.clone()
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
self.inline_assistant_model
.clone()
.or_else(|| self.default_model())
}
pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.editor_model.as_ref()?.model.clone()
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
self.commit_message_model
.clone()
.or_else(|| self.default_model())
}
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
self.thread_summary_model
.clone()
.or_else(|| self.default_model())
}
/// Selects and sets the inline alternatives for language models based on

View file

@ -168,11 +168,11 @@ impl LanguageModelSelector {
}
fn get_active_model_index(cx: &App) -> usize {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let active_model = LanguageModelRegistry::read_global(cx).default_model();
Self::all_models(cx)
.iter()
.position(|model_info| {
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
})
.unwrap_or(0)
}
@ -406,13 +406,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let model_info = self.filtered_models.get(ix)?;
let provider_name: String = model_info.model.provider_name().0.clone().into();
let active_provider_id = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|m| m.id());
let active_model = LanguageModelRegistry::read_global(cx).default_model();
let active_model_id = LanguageModelRegistry::read_global(cx)
.active_model()
.map(|m| m.id());
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
&& Some(model_info.model.id()) == active_model_id;

View file

@ -9,7 +9,7 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry, language_settings::SoftWrap};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use picker::{Picker, PickerDelegate};
use release_channel::ReleaseChannel;
@ -777,7 +777,9 @@ impl PromptLibrary {
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(ConfiguredModel { provider, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
@ -880,7 +882,9 @@ impl PromptLibrary {
}
fn count_tokens(&mut self, prompt_id: PromptId, window: &mut Window, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
else {
return;
};
if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
@ -967,7 +971,9 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let model = LanguageModelRegistry::read_global(cx).active_model();
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
let settings = ThemeSettings::get_global(cx);
Some(

View file

@ -19,7 +19,8 @@ To further customize providers, you can use `settings.json` to do that as follow
- [Configuring endpoints](#custom-endpoint)
- [Configuring timeouts](#provider-timeout)
- [Configuring default model](#default-model)
- [Configuring models](#default-model)
- [Configuring feature-specific models](#feature-specific-models)
- [Configuring alternative models for inline assists](#alternative-assists)
### Zed AI {#zed-ai}
@ -281,8 +282,24 @@ Example configuration for using X.ai Grok with Zed:
"enabled": true,
"default_model": {
"provider": "zed.dev",
"model": "claude-3-7-sonnet"
},
"editor_model": {
"provider": "openai",
"model": "gpt-4o"
},
"inline_assistant_model": {
"provider": "anthropic",
"model": "claude-3-5-sonnet"
},
"commit_message_model": {
"provider": "openai",
"model": "gpt-4o-mini"
},
"thread_summary_model": {
"provider": "google",
"model": "gemini-1.5-flash"
},
"version": "2",
"button": true,
"default_width": 480,
@ -328,7 +345,7 @@ To do so, add the following to your Zed `settings.json`:
Where `some-provider` can be any of the following values: `anthropic`, `google`, `ollama`, `openai`.
#### Configuring the default model {#default-model}
#### Configuring models {#default-model}
The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default.
You can also manually edit the `default_model` object in your settings:
@ -345,6 +362,47 @@ You can also manually edit the `default_model` object in your settings:
}
```
#### Feature-specific models {#feature-specific-models}
> Currently only available in [Preview](https://zed.dev/releases/preview).
Zed allows you to configure different models for specific features.
This provides flexibility to use more powerful models for certain tasks while using faster or more efficient models for others.
If a feature-specific model is not set, it will fall back to using the default model, which is the one you set on the Agent Panel.
You can configure the following feature-specific models:
- Thread summary model: Used for generating thread summaries
- Inline assistant model: Used for the inline assistant feature
- Commit message model: Used for generating Git commit messages
Example configuration:
```json
{
"assistant": {
"version": "2",
"default_model": {
"provider": "zed.dev",
"model": "claude-3-7-sonnet"
},
"inline_assistant_model": {
"provider": "anthropic",
"model": "claude-3-5-sonnet"
},
"commit_message_model": {
"provider": "openai",
"model": "gpt-4o-mini"
},
"thread_summary_model": {
"provider": "google",
"model": "gemini-2.0-flash"
}
}
}
```
#### Configuring alternative models for inline assists {#alternative-assists}
You can configure additional models that will be used to perform inline assists in parallel. When you do this,