From 187356ab9b6330d0e8260e6886562064c52e808d Mon Sep 17 00:00:00 2001
From: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
Date: Fri, 8 Nov 2024 10:08:59 -0300
Subject: [PATCH] assistant: Show only configured models in the model picker
(#20392)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Closes https://github.com/zed-industries/zed/issues/16568
This PR introduces some changes to how we display models in the model
selector within the assistant panel. Basically, it comes down to this:
- If you don't have any provider configured, you should see _all_
available models in the picker
- But, once you've configured some, you should _only_ see models from
them in the picker
Visually, nothing's changed much aside from the added "Configured
Models" label at the top to ensure the understanding that that's a list
of, well, configured models only. 😬
Release Notes:
- Change model selector in the assistant panel to only show configured
models
---
crates/assistant/src/assistant_panel.rs | 2 +-
crates/assistant/src/model_selector.rs | 110 ++++++++++++++------
crates/language_model/src/provider/cloud.rs | 29 ++----
3 files changed, 88 insertions(+), 53 deletions(-)
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 6900a536e9..61afa60154 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -4507,7 +4507,6 @@ impl Render for ContextEditorToolbarItem {
fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement {
let left_side = h_flex()
.group("chat-title-group")
- .pl_0p5()
.gap_1()
.items_center()
.flex_grow()
@@ -4598,6 +4597,7 @@ impl Render for ContextEditorToolbarItem {
.children(self.render_remaining_tokens(cx));
h_flex()
+ .px_0p5()
.size_full()
.gap_2()
.justify_between()
diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs
index c9fbdd36c3..1b26b8b5ad 100644
--- a/crates/assistant/src/model_selector.rs
+++ b/crates/assistant/src/model_selector.rs
@@ -1,21 +1,17 @@
use feature_flags::ZedPro;
-use gpui::Action;
-use gpui::DismissEvent;
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use proto::Plan;
use workspace::ShowConfiguration;
use std::sync::Arc;
-use ui::ListItemSpacing;
use crate::assistant_settings::AssistantSettings;
use fs::Fs;
-use gpui::SharedString;
-use gpui::Task;
+use gpui::{Action, AnyElement, DismissEvent, SharedString, Task};
use picker::{Picker, PickerDelegate};
use settings::update_settings_file;
-use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
@@ -85,14 +81,36 @@ impl PickerDelegate for ModelPickerDelegate {
fn update_matches(&mut self, query: String, cx: &mut ViewContext>) -> Task<()> {
let all_models = self.all_models.clone();
+
+ let llm_registry = LanguageModelRegistry::global(cx);
+
+ let configured_models: Vec<_> = llm_registry
+ .read(cx)
+ .providers()
+ .iter()
+ .filter(|provider| provider.is_authenticated(cx))
+ .map(|provider| provider.id())
+ .collect();
+
cx.spawn(|this, mut cx| async move {
let filtered_models = cx
.background_executor()
.spawn(async move {
- if query.is_empty() {
+ let displayed_models = if configured_models.is_empty() {
all_models
} else {
all_models
+ .into_iter()
+ .filter(|model_info| {
+ configured_models.contains(&model_info.model.provider_id())
+ })
+ .collect::>()
+ };
+
+ if query.is_empty() {
+ displayed_models
+ } else {
+ displayed_models
.into_iter()
.filter(|model_info| {
model_info
@@ -141,6 +159,29 @@ impl PickerDelegate for ModelPickerDelegate {
fn dismissed(&mut self, _cx: &mut ViewContext>) {}
+ fn render_header(&self, cx: &mut ViewContext>) -> Option {
+ let configured_models_count = LanguageModelRegistry::global(cx)
+ .read(cx)
+ .providers()
+ .iter()
+ .filter(|provider| provider.is_authenticated(cx))
+ .count();
+
+ if configured_models_count > 0 {
+ Some(
+ Label::new("Configured Models")
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .mt_1()
+ .mb_0p5()
+ .ml_3()
+ .into_any_element(),
+ )
+ } else {
+ None
+ }
+ }
+
fn render_match(
&self,
ix: usize,
@@ -148,9 +189,10 @@ impl PickerDelegate for ModelPickerDelegate {
cx: &mut ViewContext>,
) -> Option {
use feature_flags::FeatureFlagAppExt;
- let model_info = self.filtered_models.get(ix)?;
let show_badges = cx.has_flag::();
- let provider_name: String = model_info.model.provider_name().0.into();
+
+ let model_info = self.filtered_models.get(ix)?;
+ let provider_name: String = model_info.model.provider_name().0.clone().into();
Some(
ListItem::new(ix)
@@ -165,27 +207,32 @@ impl PickerDelegate for ModelPickerDelegate {
),
)
.child(
- h_flex().w_full().justify_between().min_w(px(200.)).child(
- h_flex()
- .gap_1p5()
- .child(Label::new(model_info.model.name().0.clone()))
- .child(
- Label::new(provider_name)
- .size(LabelSize::XSmall)
- .color(Color::Muted),
- )
- .children(match model_info.availability {
- LanguageModelAvailability::Public => None,
- LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
- LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
- show_badges.then(|| {
- Label::new("Pro")
- .size(LabelSize::XSmall)
- .color(Color::Muted)
- })
- }
- }),
- ),
+ h_flex()
+ .w_full()
+ .items_center()
+ .gap_1p5()
+ .min_w(px(200.))
+ .child(Label::new(model_info.model.name().0.clone()))
+ .child(
+ h_flex()
+ .gap_0p5()
+ .child(
+ Label::new(provider_name)
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ )
+ .children(match model_info.availability {
+ LanguageModelAvailability::Public => None,
+ LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
+ LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
+ show_badges.then(|| {
+ Label::new("Pro")
+ .size(LabelSize::XSmall)
+ .color(Color::Muted)
+ })
+ }
+ }),
+ ),
)
.end_slot(div().when(model_info.is_selected, |this| {
this.child(
@@ -213,7 +260,7 @@ impl PickerDelegate for ModelPickerDelegate {
.justify_between()
.when(cx.has_flag::(), |this| {
this.child(match plan {
- // Already a zed pro subscriber
+ // Already a Zed Pro subscriber
Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
.icon(IconName::ZedAssistant)
.icon_size(IconSize::Small)
@@ -254,6 +301,7 @@ impl RenderOnce for ModelSelector {
let selected_provider = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|m| m.id());
+
let selected_model = LanguageModelRegistry::read_global(cx)
.active_model()
.map(|m| m.id());
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index b81f8e6881..971f2d3113 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -912,7 +912,7 @@ impl Render for ConfigurationView {
let is_pro = plan == Some(proto::Plan::ZedPro);
let subscription_text = Label::new(if is_pro {
- "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
+ "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
} else {
"You have basic access to models from Anthropic through the Zed AI Free plan."
});
@@ -957,27 +957,14 @@ impl Render for ConfigurationView {
})
} else {
v_flex()
- .gap_6()
- .child(Label::new("Use the zed.dev to access language models."))
+ .gap_2()
+ .child(Label::new("Use Zed AI to access hosted language models."))
.child(
- v_flex()
- .gap_2()
- .child(
- Button::new("sign_in", "Sign in")
- .icon_color(Color::Muted)
- .icon(IconName::Github)
- .icon_position(IconPosition::Start)
- .style(ButtonStyle::Filled)
- .full_width()
- .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
- )
- .child(
- div().flex().w_full().items_center().child(
- Label::new("Sign in to enable collaboration.")
- .color(Color::Muted)
- .size(LabelSize::Small),
- ),
- ),
+ Button::new("sign_in", "Sign In")
+ .icon_color(Color::Muted)
+ .icon(IconName::Github)
+ .icon_position(IconPosition::Start)
+ .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
)
}
}