ai: Separate model settings for each feature (#28088)
Closes: https://github.com/zed-industries/zed/issues/20582 Allows users to select a specific model for each AI-powered feature: - Agent panel - Inline assistant - Thread summarization - Commit message generation If unspecified for a given feature, it will use the `default_model` setting. Release Notes: - Added support for configuring a specific model for each AI-powered feature --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
This commit is contained in:
parent
cf0d1e4229
commit
43cb925a59
27 changed files with 670 additions and 381 deletions
|
@ -17,20 +17,25 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
active_model: Option<ActiveModel>,
|
||||
editor_model: Option<ActiveModel>,
|
||||
default_model: Option<ConfiguredModel>,
|
||||
inline_assistant_model: Option<ConfiguredModel>,
|
||||
commit_message_model: Option<ConfiguredModel>,
|
||||
thread_summary_model: Option<ConfiguredModel>,
|
||||
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
|
||||
}
|
||||
|
||||
pub struct ActiveModel {
|
||||
provider: Arc<dyn LanguageModelProvider>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
#[derive(Clone)]
|
||||
pub struct ConfiguredModel {
|
||||
pub provider: Arc<dyn LanguageModelProvider>,
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ActiveModelChanged,
|
||||
EditorModelChanged,
|
||||
DefaultModelChanged,
|
||||
InlineAssistantModelChanged,
|
||||
CommitMessageModelChanged,
|
||||
ThreadSummaryModelChanged,
|
||||
ProviderStateChanged,
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
|
@ -54,7 +59,7 @@ impl LanguageModelRegistry {
|
|||
let mut registry = Self::default();
|
||||
registry.register_provider(fake_provider.clone(), cx);
|
||||
let model = fake_provider.provided_models(cx)[0].clone();
|
||||
registry.set_active_model(Some(model), cx);
|
||||
registry.set_default_model(Some(model), cx);
|
||||
registry
|
||||
});
|
||||
cx.set_global(GlobalLanguageModelRegistry(registry));
|
||||
|
@ -114,7 +119,7 @@ impl LanguageModelRegistry {
|
|||
self.providers.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn select_active_model(
|
||||
pub fn select_default_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
|
@ -126,11 +131,11 @@ impl LanguageModelRegistry {
|
|||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_active_model(Some(model), cx);
|
||||
self.set_default_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_editor_model(
|
||||
pub fn select_inline_assistant_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
|
@ -142,23 +147,43 @@ impl LanguageModelRegistry {
|
|||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_editor_model(Some(model), cx);
|
||||
self.set_inline_assistant_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_provider(
|
||||
pub fn select_commit_message_model(
|
||||
&mut self,
|
||||
provider: Option<Arc<dyn LanguageModelProvider>>,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.active_model = provider.map(|provider| ActiveModel {
|
||||
provider,
|
||||
model: None,
|
||||
});
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_commit_message_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_model(
|
||||
pub fn select_thread_summary_model(
|
||||
&mut self,
|
||||
provider: &LanguageModelProviderId,
|
||||
model_id: &LanguageModelId,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(provider) = self.provider(provider) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let models = provider.provided_models(cx);
|
||||
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
|
||||
self.set_thread_summary_model(Some(model), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_default_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -166,21 +191,18 @@ impl LanguageModelRegistry {
|
|||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.active_model = Some(ActiveModel {
|
||||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
self.default_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.active_model = None;
|
||||
cx.emit(Event::ActiveModelChanged);
|
||||
self.default_model = None;
|
||||
cx.emit(Event::DefaultModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_editor_model(
|
||||
pub fn set_inline_assistant_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -188,35 +210,80 @@ impl LanguageModelRegistry {
|
|||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.editor_model = Some(ActiveModel {
|
||||
provider,
|
||||
model: Some(model),
|
||||
});
|
||||
cx.emit(Event::EditorModelChanged);
|
||||
self.inline_assistant_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
} else {
|
||||
log::warn!("Active model's provider not found in registry");
|
||||
log::warn!("Inline assistant model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.editor_model = None;
|
||||
cx.emit(Event::EditorModelChanged);
|
||||
self.inline_assistant_model = None;
|
||||
cx.emit(Event::InlineAssistantModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
pub fn set_commit_message_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.commit_message_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
} else {
|
||||
log::warn!("Commit message model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.commit_message_model = None;
|
||||
cx.emit(Event::CommitMessageModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(
|
||||
&mut self,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(model) = model {
|
||||
let provider_id = model.provider_id();
|
||||
if let Some(provider) = self.providers.get(&provider_id).cloned() {
|
||||
self.thread_summary_model = Some(ConfiguredModel { provider, model });
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
} else {
|
||||
log::warn!("Thread summary model's provider not found in registry");
|
||||
}
|
||||
} else {
|
||||
self.thread_summary_model = None;
|
||||
cx.emit(Event::ThreadSummaryModelChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> Option<ConfiguredModel> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.active_model.as_ref()?.provider.clone())
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.active_model.as_ref()?.model.clone()
|
||||
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
|
||||
self.inline_assistant_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.editor_model.as_ref()?.model.clone()
|
||||
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
|
||||
self.commit_message_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
|
||||
self.thread_summary_model
|
||||
.clone()
|
||||
.or_else(|| self.default_model())
|
||||
}
|
||||
|
||||
/// Selects and sets the inline alternatives for language models based on
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue