language_model_selector: Refresh the models when the providers change (#22624)
This PR fixes an issue introduced in #21939 where the list of models in the language model selector could be outdated. Since we're no longer recreating the picker each render, we now need to make sure we are updating the list of models accordingly when there are changes to the language model providers. I noticed it specifically in Assistant1. Release Notes: - Fixed a staleness issue with the language model selector.
This commit is contained in:
parent
e4eef725de
commit
04518b11bc
1 changed files with 56 additions and 21 deletions
|
@ -2,8 +2,8 @@ use std::sync::Arc;
|
|||
|
||||
use feature_flags::ZedPro;
|
||||
use gpui::{
|
||||
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
|
||||
View, WeakView,
|
||||
Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
|
||||
Subscription, Task, View, WeakView,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
|
@ -17,6 +17,10 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>
|
|||
|
||||
pub struct LanguageModelSelector {
|
||||
picker: View<Picker<LanguageModelPickerDelegate>>,
|
||||
/// The task used to update the picker's matches when there is a change to
|
||||
/// the language model registry.
|
||||
update_matches_task: Option<Task<()>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
impl LanguageModelSelector {
|
||||
|
@ -26,7 +30,51 @@ impl LanguageModelSelector {
|
|||
) -> Self {
|
||||
let on_model_changed = Arc::new(on_model_changed);
|
||||
|
||||
let all_models = LanguageModelRegistry::global(cx)
|
||||
let all_models = Self::all_models(cx);
|
||||
let delegate = LanguageModelPickerDelegate {
|
||||
language_model_selector: cx.view().downgrade(),
|
||||
on_model_changed: on_model_changed.clone(),
|
||||
all_models: all_models.clone(),
|
||||
filtered_models: all_models,
|
||||
selected_index: 0,
|
||||
};
|
||||
|
||||
let picker =
|
||||
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
|
||||
|
||||
LanguageModelSelector {
|
||||
picker,
|
||||
update_matches_task: None,
|
||||
_subscriptions: vec![cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
Self::handle_language_model_registry_event,
|
||||
)],
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_language_model_registry_event(
|
||||
&mut self,
|
||||
_registry: Model<LanguageModelRegistry>,
|
||||
event: &language_model::Event,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let task = self.picker.update(cx, |this, cx| {
|
||||
let query = this.query(cx);
|
||||
this.delegate.all_models = Self::all_models(cx);
|
||||
this.delegate.update_matches(query, cx)
|
||||
});
|
||||
self.update_matches_task = Some(task);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
|
||||
LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
|
@ -44,20 +92,7 @@ impl LanguageModelSelector {
|
|||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let delegate = LanguageModelPickerDelegate {
|
||||
language_model_selector: cx.view().downgrade(),
|
||||
on_model_changed: on_model_changed.clone(),
|
||||
all_models: all_models.clone(),
|
||||
filtered_models: all_models,
|
||||
selected_index: 0,
|
||||
};
|
||||
|
||||
let picker =
|
||||
cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
|
||||
|
||||
LanguageModelSelector { picker }
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
|
||||
let llm_registry = LanguageModelRegistry::global(cx);
|
||||
|
||||
let configured_models: Vec<_> = llm_registry
|
||||
let configured_providers = llm_registry
|
||||
.read(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.map(|provider| provider.id())
|
||||
.collect();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let filtered_models = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
let displayed_models = if configured_models.is_empty() {
|
||||
let displayed_models = if configured_providers.is_empty() {
|
||||
all_models
|
||||
} else {
|
||||
all_models
|
||||
.into_iter()
|
||||
.filter(|model_info| {
|
||||
configured_models.contains(&model_info.model.provider_id())
|
||||
configured_providers.contains(&model_info.model.provider_id())
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue