Associate each thread with a model (#29573)

This PR makes it possible to use different LLM models in the agent
panels of two different projects, simultaneously. It also properly
restores a thread's original model when restoring it from the history,
rather than having it use the default model. As before, newly-created
threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent
panel
- Enhanced the agent panel so that when revisiting old threads, their
original model will be used.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2025-04-28 16:43:16 -07:00 committed by GitHub
parent 5102c4c002
commit 17903a0999
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 168 additions and 114 deletions

View file

@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>,
@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>,
}
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
impl LanguageModelSelector {
pub fn new(
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
model_type: ModelType,
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -52,9 +47,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, model_type, cx),
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
model_type,
get_active_model: Arc::new(get_active_model),
};
let picker = cx.new(|cx| {
@ -204,26 +199,13 @@ impl LanguageModelSelector {
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
let model_type = self.picker.read(cx).delegate.model_type;
Self::active_model_by_type(model_type, cx)
}
fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
match model_type {
ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
(self.picker.read(cx).delegate.get_active_model)(cx)
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
model_type: ModelType,
cx: &App,
active_model: Option<ConfiguredModel>,
) -> usize {
let active_model = Self::active_model_by_type(model_type, cx);
entries
.iter()
.position(|entry| {
@ -232,7 +214,7 @@ impl LanguageModelSelector {
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.model.provider_id() == model.model.provider_id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
@ -325,10 +307,10 @@ struct ModelInfo {
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
model_type: ModelType,
}
struct GroupedModels {
@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
let active_model = (self.get_active_model)(cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());