diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9e49760394..00c79c4459 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -2801,21 +2801,19 @@ pub struct ContextEditorToolbarItem { fs: Arc, workspace: WeakView, active_context_editor: Option>, - model_selector_menu_handle: PopoverMenuHandle, model_summary_editor: View, } impl ContextEditorToolbarItem { pub fn new( workspace: &Workspace, - model_selector_menu_handle: PopoverMenuHandle, + _model_selector_menu_handle: PopoverMenuHandle, model_summary_editor: View, ) -> Self { Self { fs: workspace.app_state().fs.clone(), workspace: workspace.weak_handle(), active_context_editor: None, - model_selector_menu_handle, model_summary_editor, } } @@ -2946,49 +2944,46 @@ impl Render for ContextEditorToolbarItem { }); let right_side = h_flex() .gap_2() - .child( - ModelSelector::new( - self.fs.clone(), - ButtonLike::new("active-model") - .style(ButtonStyle::Subtle) - .child( - h_flex() - .w_full() - .gap_0p5() - .child( - div() - .overflow_x_hidden() - .flex_grow() - .whitespace_nowrap() - .child( - Label::new( - LanguageModelRegistry::read_global(cx) - .active_model() - .map(|model| { - format!( - "{}: {}", - model.provider_name().0, - model.name().0 - ) - }) - .unwrap_or_else(|| "No model selected".into()), - ) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - .child( - Icon::new(IconName::ChevronDown) - .color(Color::Muted) - .size(IconSize::XSmall), - ), - ) - .tooltip(move |cx| { - Tooltip::for_action("Change Model", &ToggleModelSelector, cx) - }), - ) - .with_handle(self.model_selector_menu_handle.clone()), - ) + .child(ModelSelector::new( + self.fs.clone(), + ButtonLike::new("active-model") + .style(ButtonStyle::Subtle) + .child( + h_flex() + .w_full() + .gap_0p5() + .child( + div() + .overflow_x_hidden() + .flex_grow() + .whitespace_nowrap() + .child( + Label::new( + LanguageModelRegistry::read_global(cx) + .active_model() + .map(|model| { + format!( + "{}: {}", + model.provider_name().0, + model.name().0 + ) + }) + .unwrap_or_else(|| "No model selected".into()), + ) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .child( + Icon::new(IconName::ChevronDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ) + .tooltip(move |cx| { + Tooltip::for_action("Change Model", &ToggleModelSelector, cx) + }), + )) .children(self.render_remaining_tokens(cx)) .child(self.render_inject_context_menu(cx)); diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index c01e6e9298..5c2025f1c2 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -1,21 +1,45 @@ -use std::sync::Arc; - -use crate::{assistant_settings::AssistantSettings, ShowConfiguration}; -use fs::Fs; -use gpui::{Action, SharedString}; -use language_model::{LanguageModelAvailability, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry}; use proto::Plan; + +use std::sync::Arc; +use ui::ListItemSpacing; + +use crate::assistant_settings::AssistantSettings; +use crate::ShowConfiguration; +use fs::Fs; +use gpui::Action; +use gpui::SharedString; +use gpui::Task; +use picker::{Picker, PickerDelegate}; use settings::update_settings_file; -use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger}; +use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger}; + +const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; #[derive(IntoElement)] pub struct ModelSelector { - handle: Option>, + handle: Option>>, fs: Arc, trigger: T, info_text: Option, } +pub struct ModelPickerDelegate { + fs: Arc, + all_models: Vec, + filtered_models: Vec, + selected_index: usize, +} + +#[derive(Clone)] +struct ModelInfo { + model: Arc, + _provider_name: SharedString, + provider_icon: IconName, + availability: LanguageModelAvailability, + is_selected: bool, +} + impl ModelSelector { pub fn new(fs: Arc, trigger: T) -> Self { ModelSelector { @@ -26,7 +50,7 @@ impl ModelSelector { } } - pub fn with_handle(mut self, handle: PopoverMenuHandle) -> Self { + pub fn with_handle(mut self, handle: PopoverMenuHandle>) -> Self { self.handle = Some(handle); self } @@ -37,148 +61,228 @@ impl ModelSelector { } } -impl RenderOnce for ModelSelector { - fn render(self, _cx: &mut WindowContext) -> impl IntoElement { - let mut menu = PopoverMenu::new("model-switcher"); - if let Some(handle) = self.handle { - menu = menu.with_handle(handle); - } +impl PickerDelegate for ModelPickerDelegate { + type ListItem = ListItem; - let info_text = self.info_text.clone(); + fn match_count(&self) -> usize { + self.filtered_models.len() + } - menu.menu(move |cx| { - ContextMenu::build(cx, |mut menu, cx| { - if let Some(info_text) = info_text.clone() { - menu = menu - .custom_row(move |_cx| { - Label::new(info_text.clone()) - .color(Color::Muted) - .into_any_element() - }) - .separator(); - } + fn selected_index(&self) -> usize { + self.selected_index + } - for (index, provider) in LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .into_iter() - .enumerate() - { - let provider_icon = provider.icon(); - let provider_name = provider.name().0.clone(); + fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext>) { + self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1)); + cx.notify(); + } - if index > 0 { - menu = menu.separator(); + fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc { + "Select a model...".into() + } + + fn update_matches(&mut self, query: String, cx: &mut ViewContext>) -> Task<()> { + let all_models = self.all_models.clone(); + cx.spawn(|this, mut cx| async move { + let filtered_models = cx + .background_executor() + .spawn(async move { + if query.is_empty() { + all_models + } else { + all_models + .into_iter() + .filter(|model_info| { + model_info + .model + .name() + .0 + .to_lowercase() + .contains(&query.to_lowercase()) + }) + .collect() } - menu = menu.custom_row(move |_| { - h_flex() - .pb_1() - .gap_1p5() - .w_full() - .child( - Icon::new(provider_icon) - .color(Color::Muted) + }) + .await; + + this.update(&mut cx, |this, cx| { + this.delegate.filtered_models = filtered_models; + this.delegate.set_selected_index(0, cx); + cx.notify(); + }) + .ok(); + }) + } + + fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext>) { + if let Some(model_info) = self.filtered_models.get(self.selected_index) { + let model = model_info.model.clone(); + update_settings_file::(self.fs.clone(), cx, move |settings, _| { + settings.set_model(model.clone()) + }); + + // Update the selection status + let selected_model_id = model_info.model.id(); + for model in &mut self.all_models { + model.is_selected = model.model.id() == selected_model_id; + } + for model in &mut self.filtered_models { + model.is_selected = model.model.id() == selected_model_id; + } + } + } + + fn dismissed(&mut self, _cx: &mut ViewContext>) {} + + fn render_match( + &self, + ix: usize, + selected: bool, + cx: &mut ViewContext>, + ) -> Option { + let model_info = self.filtered_models.get(ix)?; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .selected(selected) + .start_slot( + div().pr_1().child( + Icon::new(model_info.provider_icon) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + ) + .child( + h_flex() + .w_full() + .justify_between() + .font_buffer(cx) + .min_w(px(200.)) + .child( + h_flex() + .gap_2() + .child(Label::new(model_info.model.name().0.clone())) + .children(match model_info.availability { + LanguageModelAvailability::Public => None, + LanguageModelAvailability::RequiresPlan(Plan::Free) => None, + LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => Some( + Label::new("Pro") + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + }), + ) + .child(div().when(model_info.is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) .size(IconSize::Small), ) - .child(Label::new(provider_name.clone())) - .into_any_element() - }); + })), + ), + ) + } - let available_models = provider.provided_models(cx); - if available_models.is_empty() { - menu = menu.custom_entry( - { - move |_| { - h_flex() - .w_full() - .gap_1() - .child(Icon::new(IconName::Settings)) - .child(Label::new("Configure")) - .into_any() - } - }, - { - |cx| { - cx.dispatch_action(ShowConfiguration.boxed_clone()); - } - }, - ); - } + fn render_footer(&self, cx: &mut ViewContext>) -> Option { + let plan = proto::Plan::ZedPro; + let is_trial = false; - let selected_provider = LanguageModelRegistry::read_global(cx) - .active_provider() - .map(|m| m.id()); - let selected_model = LanguageModelRegistry::read_global(cx) - .active_model() - .map(|m| m.id()); - - for available_model in available_models { - menu = menu.custom_entry( - { - let id = available_model.id(); - let provider_id = available_model.provider_id(); - let model_name = available_model.name().0.clone(); - let availability = available_model.availability(); - let selected_model = selected_model.clone(); - let selected_provider = selected_provider.clone(); - move |cx| { - h_flex() - .w_full() - .justify_between() - .font_buffer(cx) - .min_w(px(260.)) - .child( - h_flex() - .gap_2() - .child(Label::new(model_name.clone())) - .children(match availability { - LanguageModelAvailability::Public => None, - LanguageModelAvailability::RequiresPlan( - Plan::Free, - ) => None, - LanguageModelAvailability::RequiresPlan( - Plan::ZedPro, - ) => Some( - Label::new("Pro") - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - }), - ) - .child(div().when( - selected_model.as_ref() == Some(&id) - && selected_provider.as_ref() == Some(&provider_id), - |this| { - this.child( - Icon::new(IconName::Check) - .color(Color::Accent) - .size(IconSize::Small), - ) - }, - )) - .into_any() - } - }, - { - let fs = self.fs.clone(); - let model = available_model.clone(); - move |cx| { - let model = model.clone(); - update_settings_file::( - fs.clone(), - cx, - move |settings, _| settings.set_model(model), - ); - } - }, - ); - } - } - menu - }) - .into() - }) - .trigger(self.trigger) - .attach(gpui::AnchorCorner::BottomLeft) + Some( + h_flex() + .w_full() + .border_t_1() + .border_color(cx.theme().colors().border) + .p_1() + .gap_4() + .justify_between() + .child(match plan { + // Already a zed pro subscriber + Plan::ZedPro => Button::new("zed-pro", "Zed Pro") + .icon(IconName::ZedAssistant) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(|_, cx| { + cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings)) + }), + // Free user + Plan::Free => Button::new( + "try-pro", + if is_trial { + "Upgrade to Pro" + } else { + "Try Pro" + }, + ) + .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)), + }) + .child( + Button::new("configure", "Configure") + .icon(IconName::Settings) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(|_, cx| { + cx.dispatch_action(ShowConfiguration.boxed_clone()); + }), + ) + .into_any(), + ) + } +} + +impl RenderOnce for ModelSelector { + fn render(self, cx: &mut WindowContext) -> impl IntoElement { + let selected_provider = LanguageModelRegistry::read_global(cx) + .active_provider() + .map(|m| m.id()); + let selected_model = LanguageModelRegistry::read_global(cx) + .active_model() + .map(|m| m.id()); + + let all_models = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .iter() + .flat_map(|provider| { + let provider_name = provider.name().0.clone(); + let provider_icon = provider.icon(); + let provider_id = provider.id(); + let selected_model = selected_model.clone(); + let selected_provider = selected_provider.clone(); + + provider.provided_models(cx).into_iter().map(move |model| { + let model = model.clone(); + + ModelInfo { + model: model.clone(), + _provider_name: provider_name.clone(), + provider_icon, + availability: model.availability(), + is_selected: selected_model.as_ref() == Some(&model.id()) + && selected_provider.as_ref() == Some(&provider_id), + } + }) + }) + .collect::>(); + + let delegate = ModelPickerDelegate { + fs: self.fs.clone(), + all_models: all_models.clone(), + filtered_models: all_models, + selected_index: 0, + }; + + let picker_view = cx.new_view(|cx| { + let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())); + picker + }); + + PopoverMenu::new("model-switcher") + .menu(move |_cx| Some(picker_view.clone())) + .trigger(self.trigger) + .attach(gpui::AnchorCorner::BottomLeft) } } diff --git a/crates/assistant/src/model_selector_old.rs b/crates/assistant/src/model_selector_old.rs new file mode 100644 index 0000000000..c01e6e9298 --- /dev/null +++ b/crates/assistant/src/model_selector_old.rs @@ -0,0 +1,184 @@ +use std::sync::Arc; + +use crate::{assistant_settings::AssistantSettings, ShowConfiguration}; +use fs::Fs; +use gpui::{Action, SharedString}; +use language_model::{LanguageModelAvailability, LanguageModelRegistry}; +use proto::Plan; +use settings::update_settings_file; +use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger}; + +#[derive(IntoElement)] +pub struct ModelSelector { + handle: Option>, + fs: Arc, + trigger: T, + info_text: Option, +} + +impl ModelSelector { + pub fn new(fs: Arc, trigger: T) -> Self { + ModelSelector { + handle: None, + fs, + trigger, + info_text: None, + } + } + + pub fn with_handle(mut self, handle: PopoverMenuHandle) -> Self { + self.handle = Some(handle); + self + } + + pub fn with_info_text(mut self, text: impl Into) -> Self { + self.info_text = Some(text.into()); + self + } +} + +impl RenderOnce for ModelSelector { + fn render(self, _cx: &mut WindowContext) -> impl IntoElement { + let mut menu = PopoverMenu::new("model-switcher"); + if let Some(handle) = self.handle { + menu = menu.with_handle(handle); + } + + let info_text = self.info_text.clone(); + + menu.menu(move |cx| { + ContextMenu::build(cx, |mut menu, cx| { + if let Some(info_text) = info_text.clone() { + menu = menu + .custom_row(move |_cx| { + Label::new(info_text.clone()) + .color(Color::Muted) + .into_any_element() + }) + .separator(); + } + + for (index, provider) in LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .into_iter() + .enumerate() + { + let provider_icon = provider.icon(); + let provider_name = provider.name().0.clone(); + + if index > 0 { + menu = menu.separator(); + } + menu = menu.custom_row(move |_| { + h_flex() + .pb_1() + .gap_1p5() + .w_full() + .child( + Icon::new(provider_icon) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(Label::new(provider_name.clone())) + .into_any_element() + }); + + let available_models = provider.provided_models(cx); + if available_models.is_empty() { + menu = menu.custom_entry( + { + move |_| { + h_flex() + .w_full() + .gap_1() + .child(Icon::new(IconName::Settings)) + .child(Label::new("Configure")) + .into_any() + } + }, + { + |cx| { + cx.dispatch_action(ShowConfiguration.boxed_clone()); + } + }, + ); + } + + let selected_provider = LanguageModelRegistry::read_global(cx) + .active_provider() + .map(|m| m.id()); + let selected_model = LanguageModelRegistry::read_global(cx) + .active_model() + .map(|m| m.id()); + + for available_model in available_models { + menu = menu.custom_entry( + { + let id = available_model.id(); + let provider_id = available_model.provider_id(); + let model_name = available_model.name().0.clone(); + let availability = available_model.availability(); + let selected_model = selected_model.clone(); + let selected_provider = selected_provider.clone(); + move |cx| { + h_flex() + .w_full() + .justify_between() + .font_buffer(cx) + .min_w(px(260.)) + .child( + h_flex() + .gap_2() + .child(Label::new(model_name.clone())) + .children(match availability { + LanguageModelAvailability::Public => None, + LanguageModelAvailability::RequiresPlan( + Plan::Free, + ) => None, + LanguageModelAvailability::RequiresPlan( + Plan::ZedPro, + ) => Some( + Label::new("Pro") + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + }), + ) + .child(div().when( + selected_model.as_ref() == Some(&id) + && selected_provider.as_ref() == Some(&provider_id), + |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + }, + )) + .into_any() + } + }, + { + let fs = self.fs.clone(); + let model = available_model.clone(); + move |cx| { + let model = model.clone(); + update_settings_file::( + fs.clone(), + cx, + move |settings, _| settings.set_model(model), + ); + } + }, + ); + } + } + menu + }) + .into() + }) + .trigger(self.trigger) + .attach(gpui::AnchorCorner::BottomLeft) + } +}