language_model_selector: Refresh the models when the providers change (#22624)

This PR fixes an issue introduced in #21939 where the list of models in
the language model selector could be outdated.

Since we're no longer recreating the picker each render, we now need to
make sure we are updating the list of models accordingly when there are
changes to the language model providers.

I noticed it specifically in Assistant1.

Release Notes:

- Fixed a staleness issue with the language model selector.
This commit is contained in:
Marshall Bowers 2025-01-03 14:38:08 -05:00 committed by GitHub
parent e4eef725de
commit 04518b11bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,8 +2,8 @@ use std::sync::Arc;
use feature_flags::ZedPro;
use gpui::{
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
View, WeakView,
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
Subscription, Task, View, WeakView,
};
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use picker::{Picker, PickerDelegate};
@ -17,6 +17,10 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>
pub struct LanguageModelSelector {
picker: View<Picker<LanguageModelPickerDelegate>>,
/// The task used to update the picker's matches when there is a change to
/// the language model registry.
update_matches_task: Option<Task<()>>,
_subscriptions: Vec<Subscription>,
}
impl LanguageModelSelector {
@ -26,7 +30,51 @@ impl LanguageModelSelector {
) -> Self {
let on_model_changed = Arc::new(on_model_changed);
let all_models = LanguageModelRegistry::global(cx)
let all_models = Self::all_models(cx);
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.view().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: all_models.clone(),
filtered_models: all_models,
selected_index: 0,
};
let picker =
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
LanguageModelSelector {
picker,
update_matches_task: None,
_subscriptions: vec![cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_language_model_registry_event,
)],
}
}
fn handle_language_model_registry_event(
&mut self,
_registry: Model<LanguageModelRegistry>,
event: &language_model::Event,
cx: &mut ViewContext<Self>,
) {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let task = self.picker.update(cx, |this, cx| {
let query = this.query(cx);
this.delegate.all_models = Self::all_models(cx);
this.delegate.update_matches(query, cx)
});
self.update_matches_task = Some(task);
}
_ => {}
}
}
fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
@ -44,20 +92,7 @@ impl LanguageModelSelector {
}
})
})
.collect::<Vec<_>>();
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.view().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: all_models.clone(),
filtered_models: all_models,
selected_index: 0,
};
let picker =
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
LanguageModelSelector { picker }
.collect::<Vec<_>>()
}
}
@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let llm_registry = LanguageModelRegistry::global(cx);
let configured_models: Vec<_> = llm_registry
let configured_providers = llm_registry
.read(cx)
.providers()
.iter()
.filter(|provider| provider.is_authenticated(cx))
.map(|provider| provider.id())
.collect();
.collect::<Vec<_>>();
cx.spawn(|this, mut cx| async move {
let filtered_models = cx
.background_executor()
.spawn(async move {
let displayed_models = if configured_models.is_empty() {
let displayed_models = if configured_providers.is_empty() {
all_models
} else {
all_models
.into_iter()
.filter(|model_info| {
configured_models.contains(&model_info.model.provider_id())
configured_providers.contains(&model_info.model.provider_id())
})
.collect::<Vec<_>>()
};