ai: Separate model settings for each feature (#28088)

Closes: https://github.com/zed-industries/zed/issues/20582

Allows users to select a specific model for each AI-powered feature:
- Agent panel
- Inline assistant
- Thread summarization
- Commit message generation

If unspecified for a given feature, it will use the `default_model`
setting.

Release Notes:

- Added support for configuring a specific model for each AI-powered
feature

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
Agus Zubiaga 2025-04-04 11:40:55 -03:00 committed by GitHub
parent cf0d1e4229
commit 43cb925a59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 670 additions and 381 deletions

View file

@ -17,20 +17,25 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
editor_model: Option<ActiveModel>,
default_model: Option<ConfiguredModel>,
inline_assistant_model: Option<ConfiguredModel>,
commit_message_model: Option<ConfiguredModel>,
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
pub struct ActiveModel {
provider: Arc<dyn LanguageModelProvider>,
model: Option<Arc<dyn LanguageModel>>,
#[derive(Clone)]
pub struct ConfiguredModel {
pub provider: Arc<dyn LanguageModelProvider>,
pub model: Arc<dyn LanguageModel>,
}
pub enum Event {
ActiveModelChanged,
EditorModelChanged,
DefaultModelChanged,
InlineAssistantModelChanged,
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@ -54,7 +59,7 @@ impl LanguageModelRegistry {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
registry.set_active_model(Some(model), cx);
registry.set_default_model(Some(model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
@ -114,7 +119,7 @@ impl LanguageModelRegistry {
self.providers.get(id).cloned()
}
pub fn select_active_model(
pub fn select_default_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@ -126,11 +131,11 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_active_model(Some(model), cx);
self.set_default_model(Some(model), cx);
}
}
pub fn select_editor_model(
pub fn select_inline_assistant_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
@ -142,23 +147,43 @@ impl LanguageModelRegistry {
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_editor_model(Some(model), cx);
self.set_inline_assistant_model(Some(model), cx);
}
}
pub fn set_active_provider(
pub fn select_commit_message_model(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
self.active_model = provider.map(|provider| ActiveModel {
provider,
model: None,
});
cx.emit(Event::ActiveModelChanged);
let Some(provider) = self.provider(provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_commit_message_model(Some(model), cx);
}
}
pub fn set_active_model(
pub fn select_thread_summary_model(
&mut self,
provider: &LanguageModelProviderId,
model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
let Some(provider) = self.provider(provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_thread_summary_model(Some(model), cx);
}
}
pub fn set_default_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@ -166,21 +191,18 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.active_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(Event::ActiveModelChanged);
self.default_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::DefaultModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
cx.emit(Event::ActiveModelChanged);
self.default_model = None;
cx.emit(Event::DefaultModelChanged);
}
}
pub fn set_editor_model(
pub fn set_inline_assistant_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
@ -188,35 +210,80 @@ impl LanguageModelRegistry {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.editor_model = Some(ActiveModel {
provider,
model: Some(model),
});
cx.emit(Event::EditorModelChanged);
self.inline_assistant_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::InlineAssistantModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
log::warn!("Inline assistant model's provider not found in registry");
}
} else {
self.editor_model = None;
cx.emit(Event::EditorModelChanged);
self.inline_assistant_model = None;
cx.emit(Event::InlineAssistantModelChanged);
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
pub fn set_commit_message_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.commit_message_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::CommitMessageModelChanged);
} else {
log::warn!("Commit message model's provider not found in registry");
}
} else {
self.commit_message_model = None;
cx.emit(Event::CommitMessageModelChanged);
}
}
pub fn set_thread_summary_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) {
if let Some(model) = model {
let provider_id = model.provider_id();
if let Some(provider) = self.providers.get(&provider_id).cloned() {
self.thread_summary_model = Some(ConfiguredModel { provider, model });
cx.emit(Event::ThreadSummaryModelChanged);
} else {
log::warn!("Thread summary model's provider not found in registry");
}
} else {
self.thread_summary_model = None;
cx.emit(Event::ThreadSummaryModelChanged);
}
}
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
Some(self.active_model.as_ref()?.provider.clone())
self.default_model.clone()
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.as_ref()?.model.clone()
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
self.inline_assistant_model
.clone()
.or_else(|| self.default_model())
}
pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.editor_model.as_ref()?.model.clone()
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
self.commit_message_model
.clone()
.or_else(|| self.default_model())
}
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
self.thread_summary_model
.clone()
.or_else(|| self.default_model())
}
/// Selects and sets the inline alternatives for language models based on