ai: Auto select user model when there's no default (#36722)
This PR identifies automatic configuration options that users can select from the agent panel. If no default provider is set in their settings, the PR defaults to the first recommended option. Additionally, it updates the selected provider for a thread when a user changes the default provider through the settings file, if the thread hasn't had any queries yet. Release Notes: - agent: automatically select a language model provider if there's no user set provider. --------- Co-authored-by: Michael Sloan <michael@zed.dev>
This commit is contained in:
parent
e15856a37f
commit
b349a8f34c
9 changed files with 184 additions and 122 deletions
|
@ -6,7 +6,6 @@ use collections::BTreeMap;
|
|||
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use util::maybe;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let registry = cx.new(|_cx| LanguageModelRegistry::default());
|
||||
|
@ -48,7 +47,9 @@ impl std::fmt::Debug for ConfigurationError {
|
|||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
default_model: Option<ConfiguredModel>,
|
||||
default_fast_model: Option<ConfiguredModel>,
|
||||
/// This model is automatically configured by a user's environment after
|
||||
/// authenticating all providers. It's only used when default_model is not available.
|
||||
environment_fallback_model: Option<ConfiguredModel>,
|
||||
inline_assistant_model: Option<ConfiguredModel>,
|
||||
commit_message_model: Option<ConfiguredModel>,
|
||||
thread_summary_model: Option<ConfiguredModel>,
|
||||
|
@ -104,9 +105,6 @@ impl ConfiguredModel {
|
|||
|
||||
pub enum Event {
|
||||
DefaultModelChanged,
|
||||
InlineAssistantModelChanged,
|
||||
CommitMessageModelChanged,
|
||||
ThreadSummaryModelChanged,
|
||||
ProviderStateChanged(LanguageModelProviderId),
|
||||
AddedProvider(LanguageModelProviderId),
|
||||
RemovedProvider(LanguageModelProviderId),
|
||||
|
@ -238,7 +236,7 @@ impl LanguageModelRegistry {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_inline_assistant_model(configured_model, cx);
|
||||
self.set_inline_assistant_model(configured_model);
|
||||
}
|
||||
|
||||
pub fn select_commit_message_model(
|
||||
|
@ -247,7 +245,7 @@ impl LanguageModelRegistry {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_commit_message_model(configured_model, cx);
|
||||
self.set_commit_message_model(configured_model);
|
||||
}
|
||||
|
||||
pub fn select_thread_summary_model(
|
||||
|
@ -256,7 +254,7 @@ impl LanguageModelRegistry {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let configured_model = model.and_then(|model| self.select_model(model, cx));
|
||||
self.set_thread_summary_model(configured_model, cx);
|
||||
self.set_thread_summary_model(configured_model);
|
||||
}
|
||||
|
||||
/// Selects and sets the inline alternatives for language models based on
|
||||
|
@ -290,68 +288,60 @@ impl LanguageModelRegistry {
|
|||
}
|
||||
|
||||
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
|
||||
match (self.default_model.as_ref(), model.as_ref()) {
|
||||
match (self.default_model(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::DefaultModelChanged),
|
||||
}
|
||||
self.default_fast_model = maybe!({
|
||||
let provider = &model.as_ref()?.provider;
|
||||
let fast_model = provider.default_fast_model(cx)?;
|
||||
Some(ConfiguredModel {
|
||||
provider: provider.clone(),
|
||||
model: fast_model,
|
||||
})
|
||||
});
|
||||
self.default_model = model;
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(
|
||||
pub fn set_environment_fallback_model(
|
||||
&mut self,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::InlineAssistantModelChanged),
|
||||
if self.default_model.is_none() {
|
||||
match (self.environment_fallback_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::DefaultModelChanged),
|
||||
}
|
||||
}
|
||||
self.environment_fallback_model = model;
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
|
||||
self.inline_assistant_model = model;
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(
|
||||
&mut self,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match (self.commit_message_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::CommitMessageModelChanged),
|
||||
}
|
||||
pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
|
||||
self.commit_message_model = model;
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(
|
||||
&mut self,
|
||||
model: Option<ConfiguredModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match (self.thread_summary_model.as_ref(), model.as_ref()) {
|
||||
(Some(old), Some(new)) if old.is_same_as(new) => {}
|
||||
(None, None) => {}
|
||||
_ => cx.emit(Event::ThreadSummaryModelChanged),
|
||||
}
|
||||
pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
|
||||
self.thread_summary_model = model;
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn default_model(&self) -> Option<ConfiguredModel> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.default_model.clone()
|
||||
self.default_model
|
||||
.clone()
|
||||
.or_else(|| self.environment_fallback_model.clone())
|
||||
}
|
||||
|
||||
pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
||||
let provider = self.default_model()?.provider;
|
||||
let fast_model = provider.default_fast_model(cx)?;
|
||||
Some(ConfiguredModel {
|
||||
provider,
|
||||
model: fast_model,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
|
||||
|
@ -365,7 +355,7 @@ impl LanguageModelRegistry {
|
|||
.or_else(|| self.default_model.clone())
|
||||
}
|
||||
|
||||
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
|
||||
pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
|
||||
return None;
|
||||
|
@ -373,11 +363,11 @@ impl LanguageModelRegistry {
|
|||
|
||||
self.commit_message_model
|
||||
.clone()
|
||||
.or_else(|| self.default_fast_model.clone())
|
||||
.or_else(|| self.default_fast_model(cx))
|
||||
.or_else(|| self.default_model.clone())
|
||||
}
|
||||
|
||||
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
|
||||
pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
|
||||
return None;
|
||||
|
@ -385,7 +375,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
self.thread_summary_model
|
||||
.clone()
|
||||
.or_else(|| self.default_fast_model.clone())
|
||||
.or_else(|| self.default_fast_model(cx))
|
||||
.or_else(|| self.default_model.clone())
|
||||
}
|
||||
|
||||
|
@ -422,4 +412,34 @@ mod tests {
|
|||
let providers = registry.read(cx).providers();
|
||||
assert!(providers.is_empty());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
|
||||
let registry = cx.new(|_| LanguageModelRegistry::default());
|
||||
|
||||
let provider = FakeLanguageModelProvider::default();
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(provider.clone(), cx);
|
||||
});
|
||||
|
||||
cx.update(|cx| provider.authenticate(cx)).await.unwrap();
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
let provider = registry.provider(&provider.id()).unwrap();
|
||||
|
||||
registry.set_environment_fallback_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: provider.clone(),
|
||||
model: provider.default_model(cx).unwrap(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
||||
let default_model = registry.default_model().unwrap();
|
||||
let fallback_model = registry.environment_fallback_model.clone().unwrap();
|
||||
|
||||
assert_eq!(default_model.model.id(), fallback_model.model.id());
|
||||
assert_eq!(default_model.provider.id(), fallback_model.provider.id());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue