
This PR identifies automatic configuration options that users can select from the agent panel. If no default provider is set in their settings, the PR defaults to the first recommended option. Additionally, it updates the selected provider for a thread when a user changes the default provider through the settings file, if the thread hasn't had any queries yet. Release Notes: - agent: Automatically select a language model provider if there's no user-set provider. --------- Co-authored-by: Michael Sloan <michael@zed.dev>
767 lines
25 KiB
Rust
767 lines
25 KiB
Rust
use std::{cmp::Reverse, sync::Arc};
|
|
|
|
use cloud_llm_client::Plan;
|
|
use collections::{HashSet, IndexMap};
|
|
use feature_flags::ZedProFeatureFlag;
|
|
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
|
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
|
|
use language_model::{
|
|
ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
|
};
|
|
use ordered_float::OrderedFloat;
|
|
use picker::{Picker, PickerDelegate};
|
|
use ui::{ListItem, ListItemSpacing, prelude::*};
|
|
|
|
/// URL opened by the "try-pro" footer button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
|
|
|
|
/// Shared callback that receives the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
|
|
/// Shared callback that reports the currently active model, or `None` when there is none.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
|
|
|
|
/// A [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
|
|
|
|
pub fn language_model_selector(
|
|
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
|
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
|
window: &mut Window,
|
|
cx: &mut Context<LanguageModelSelector>,
|
|
) -> LanguageModelSelector {
|
|
let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
|
|
Picker::list(delegate, window, cx)
|
|
.show_scrollbar(true)
|
|
.width(rems(20.))
|
|
.max_height(Some(rems(20.).into()))
|
|
}
|
|
|
|
fn all_models(cx: &App) -> GroupedModels {
|
|
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
|
|
|
|
let recommended = providers
|
|
.iter()
|
|
.flat_map(|provider| {
|
|
provider
|
|
.recommended_models(cx)
|
|
.into_iter()
|
|
.map(|model| ModelInfo {
|
|
model,
|
|
icon: provider.icon(),
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
let other = providers
|
|
.iter()
|
|
.flat_map(|provider| {
|
|
provider
|
|
.provided_models(cx)
|
|
.into_iter()
|
|
.map(|model| ModelInfo {
|
|
model,
|
|
icon: provider.icon(),
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
GroupedModels::new(other, recommended)
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct ModelInfo {
|
|
model: Arc<dyn LanguageModel>,
|
|
icon: IconName,
|
|
}
|
|
|
|
pub struct LanguageModelPickerDelegate {
|
|
on_model_changed: OnModelChanged,
|
|
get_active_model: GetActiveModel,
|
|
all_models: Arc<GroupedModels>,
|
|
filtered_entries: Vec<LanguageModelPickerEntry>,
|
|
selected_index: usize,
|
|
_subscriptions: Vec<Subscription>,
|
|
}
|
|
|
|
impl LanguageModelPickerDelegate {
|
|
fn new(
|
|
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
|
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
|
window: &mut Window,
|
|
cx: &mut Context<Picker<Self>>,
|
|
) -> Self {
|
|
let on_model_changed = Arc::new(on_model_changed);
|
|
let models = all_models(cx);
|
|
let entries = models.entries();
|
|
|
|
Self {
|
|
on_model_changed,
|
|
all_models: Arc::new(models),
|
|
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
|
|
filtered_entries: entries,
|
|
get_active_model: Arc::new(get_active_model),
|
|
_subscriptions: vec![cx.subscribe_in(
|
|
&LanguageModelRegistry::global(cx),
|
|
window,
|
|
|picker, _, event, window, cx| {
|
|
match event {
|
|
language_model::Event::ProviderStateChanged(_)
|
|
| language_model::Event::AddedProvider(_)
|
|
| language_model::Event::RemovedProvider(_) => {
|
|
let query = picker.query(cx);
|
|
picker.delegate.all_models = Arc::new(all_models(cx));
|
|
// Update matches will automatically drop the previous task
|
|
// if we get a provider event again
|
|
picker.update_matches(query, window, cx)
|
|
}
|
|
_ => {}
|
|
}
|
|
},
|
|
)],
|
|
}
|
|
}
|
|
|
|
fn get_active_model_index(
|
|
entries: &[LanguageModelPickerEntry],
|
|
active_model: Option<ConfiguredModel>,
|
|
) -> usize {
|
|
entries
|
|
.iter()
|
|
.position(|entry| {
|
|
if let LanguageModelPickerEntry::Model(model) = entry {
|
|
active_model
|
|
.as_ref()
|
|
.map(|active_model| {
|
|
active_model.model.id() == model.model.id()
|
|
&& active_model.provider.id() == model.model.provider_id()
|
|
})
|
|
.unwrap_or_default()
|
|
} else {
|
|
false
|
|
}
|
|
})
|
|
.unwrap_or(0)
|
|
}
|
|
|
|
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
|
|
(self.get_active_model)(cx)
|
|
}
|
|
}
|
|
|
|
struct GroupedModels {
|
|
recommended: Vec<ModelInfo>,
|
|
other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
|
|
}
|
|
|
|
impl GroupedModels {
|
|
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
|
let recommended_ids = recommended
|
|
.iter()
|
|
.map(|info| (info.model.provider_id(), info.model.id()))
|
|
.collect::<HashSet<_>>();
|
|
|
|
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
|
for model in other {
|
|
if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
|
|
continue;
|
|
}
|
|
|
|
let provider = model.model.provider_id();
|
|
if let Some(models) = other_by_provider.get_mut(&provider) {
|
|
models.push(model);
|
|
} else {
|
|
other_by_provider.insert(provider, vec![model]);
|
|
}
|
|
}
|
|
|
|
Self {
|
|
recommended,
|
|
other: other_by_provider,
|
|
}
|
|
}
|
|
|
|
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
|
|
let mut entries = Vec::new();
|
|
|
|
if !self.recommended.is_empty() {
|
|
entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
|
|
entries.extend(
|
|
self.recommended
|
|
.iter()
|
|
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
|
|
);
|
|
}
|
|
|
|
for models in self.other.values() {
|
|
if models.is_empty() {
|
|
continue;
|
|
}
|
|
entries.push(LanguageModelPickerEntry::Separator(
|
|
models[0].model.provider_name().0,
|
|
));
|
|
entries.extend(
|
|
models
|
|
.iter()
|
|
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
|
|
);
|
|
}
|
|
entries
|
|
}
|
|
|
|
fn model_infos(&self) -> Vec<ModelInfo> {
|
|
let other = self
|
|
.other
|
|
.values()
|
|
.flat_map(|model| model.iter())
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
self.recommended
|
|
.iter()
|
|
.chain(&other)
|
|
.cloned()
|
|
.collect::<Vec<_>>()
|
|
}
|
|
}
|
|
|
|
enum LanguageModelPickerEntry {
|
|
Model(ModelInfo),
|
|
Separator(SharedString),
|
|
}
|
|
|
|
struct ModelMatcher {
|
|
models: Vec<ModelInfo>,
|
|
bg_executor: BackgroundExecutor,
|
|
candidates: Vec<StringMatchCandidate>,
|
|
}
|
|
|
|
impl ModelMatcher {
|
|
fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
|
|
let candidates = Self::make_match_candidates(&models);
|
|
Self {
|
|
models,
|
|
bg_executor,
|
|
candidates,
|
|
}
|
|
}
|
|
|
|
pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
|
|
let mut matches = self.bg_executor.block(match_strings(
|
|
&self.candidates,
|
|
query,
|
|
false,
|
|
true,
|
|
100,
|
|
&Default::default(),
|
|
self.bg_executor.clone(),
|
|
));
|
|
|
|
let sorting_key = |mat: &StringMatch| {
|
|
let candidate = &self.candidates[mat.candidate_id];
|
|
(Reverse(OrderedFloat(mat.score)), candidate.id)
|
|
};
|
|
matches.sort_unstable_by_key(sorting_key);
|
|
|
|
let matched_models: Vec<_> = matches
|
|
.into_iter()
|
|
.map(|mat| self.models[mat.candidate_id].clone())
|
|
.collect();
|
|
|
|
matched_models
|
|
}
|
|
|
|
pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
|
|
self.models
|
|
.iter()
|
|
.filter(|m| {
|
|
m.model
|
|
.name()
|
|
.0
|
|
.to_lowercase()
|
|
.contains(&query.to_lowercase())
|
|
})
|
|
.cloned()
|
|
.collect::<Vec<_>>()
|
|
}
|
|
|
|
fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
|
|
model_infos
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(index, model)| {
|
|
StringMatchCandidate::new(
|
|
index,
|
|
&format!(
|
|
"{}/{}",
|
|
&model.model.provider_name().0,
|
|
&model.model.name().0
|
|
),
|
|
)
|
|
})
|
|
.collect::<Vec<_>>()
|
|
}
|
|
}
|
|
|
|
impl PickerDelegate for LanguageModelPickerDelegate {
|
|
type ListItem = AnyElement;
|
|
|
|
fn match_count(&self) -> usize {
|
|
self.filtered_entries.len()
|
|
}
|
|
|
|
fn selected_index(&self) -> usize {
|
|
self.selected_index
|
|
}
|
|
|
|
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
|
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
|
|
cx.notify();
|
|
}
|
|
|
|
fn can_select(
|
|
&mut self,
|
|
ix: usize,
|
|
_window: &mut Window,
|
|
_cx: &mut Context<Picker<Self>>,
|
|
) -> bool {
|
|
match self.filtered_entries.get(ix) {
|
|
Some(LanguageModelPickerEntry::Model(_)) => true,
|
|
Some(LanguageModelPickerEntry::Separator(_)) | None => false,
|
|
}
|
|
}
|
|
|
|
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
|
"Select a model…".into()
|
|
}
|
|
|
|
fn update_matches(
|
|
&mut self,
|
|
query: String,
|
|
window: &mut Window,
|
|
cx: &mut Context<Picker<Self>>,
|
|
) -> Task<()> {
|
|
let all_models = self.all_models.clone();
|
|
let active_model = (self.get_active_model)(cx);
|
|
let bg_executor = cx.background_executor();
|
|
|
|
let language_model_registry = LanguageModelRegistry::global(cx);
|
|
|
|
let configured_providers = language_model_registry
|
|
.read(cx)
|
|
.providers()
|
|
.into_iter()
|
|
.filter(|provider| provider.is_authenticated(cx))
|
|
.collect::<Vec<_>>();
|
|
|
|
let configured_provider_ids = configured_providers
|
|
.iter()
|
|
.map(|provider| provider.id())
|
|
.collect::<Vec<_>>();
|
|
|
|
let recommended_models = all_models
|
|
.recommended
|
|
.iter()
|
|
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
|
|
let available_models = all_models
|
|
.model_infos()
|
|
.iter()
|
|
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
|
|
.cloned()
|
|
.collect::<Vec<_>>();
|
|
|
|
let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
|
|
let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
|
|
|
|
let recommended = matcher_rec.exact_search(&query);
|
|
let all = matcher_all.fuzzy_search(&query);
|
|
|
|
let filtered_models = GroupedModels::new(all, recommended);
|
|
|
|
cx.spawn_in(window, async move |this, cx| {
|
|
this.update_in(cx, |this, window, cx| {
|
|
this.delegate.filtered_entries = filtered_models.entries();
|
|
// Finds the currently selected model in the list
|
|
let new_index =
|
|
Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
|
|
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
|
|
cx.notify();
|
|
})
|
|
.ok();
|
|
})
|
|
}
|
|
|
|
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
|
if let Some(LanguageModelPickerEntry::Model(model_info)) =
|
|
self.filtered_entries.get(self.selected_index)
|
|
{
|
|
let model = model_info.model.clone();
|
|
(self.on_model_changed)(model.clone(), cx);
|
|
|
|
let current_index = self.selected_index;
|
|
self.set_selected_index(current_index, window, cx);
|
|
|
|
cx.emit(DismissEvent);
|
|
}
|
|
}
|
|
|
|
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
|
cx.emit(DismissEvent);
|
|
}
|
|
|
|
fn render_match(
|
|
&self,
|
|
ix: usize,
|
|
selected: bool,
|
|
_: &mut Window,
|
|
cx: &mut Context<Picker<Self>>,
|
|
) -> Option<Self::ListItem> {
|
|
match self.filtered_entries.get(ix)? {
|
|
LanguageModelPickerEntry::Separator(title) => Some(
|
|
div()
|
|
.px_2()
|
|
.pb_1()
|
|
.when(ix > 1, |this| {
|
|
this.mt_1()
|
|
.pt_2()
|
|
.border_t_1()
|
|
.border_color(cx.theme().colors().border_variant)
|
|
})
|
|
.child(
|
|
Label::new(title)
|
|
.size(LabelSize::XSmall)
|
|
.color(Color::Muted),
|
|
)
|
|
.into_any_element(),
|
|
),
|
|
LanguageModelPickerEntry::Model(model_info) => {
|
|
let active_model = (self.get_active_model)(cx);
|
|
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
|
|
let active_model_id = active_model.map(|m| m.model.id());
|
|
|
|
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
|
|
&& Some(model_info.model.id()) == active_model_id;
|
|
|
|
let model_icon_color = if is_selected {
|
|
Color::Accent
|
|
} else {
|
|
Color::Muted
|
|
};
|
|
|
|
Some(
|
|
ListItem::new(ix)
|
|
.inset(true)
|
|
.spacing(ListItemSpacing::Sparse)
|
|
.toggle_state(selected)
|
|
.start_slot(
|
|
Icon::new(model_info.icon)
|
|
.color(model_icon_color)
|
|
.size(IconSize::Small),
|
|
)
|
|
.child(
|
|
h_flex()
|
|
.w_full()
|
|
.pl_0p5()
|
|
.gap_1p5()
|
|
.w(px(240.))
|
|
.child(Label::new(model_info.model.name().0).truncate()),
|
|
)
|
|
.end_slot(div().pr_3().when(is_selected, |this| {
|
|
this.child(
|
|
Icon::new(IconName::Check)
|
|
.color(Color::Accent)
|
|
.size(IconSize::Small),
|
|
)
|
|
}))
|
|
.into_any_element(),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn render_footer(
|
|
&self,
|
|
_: &mut Window,
|
|
cx: &mut Context<Picker<Self>>,
|
|
) -> Option<gpui::AnyElement> {
|
|
use feature_flags::FeatureFlagAppExt;
|
|
|
|
let plan = Plan::ZedPro;
|
|
|
|
Some(
|
|
h_flex()
|
|
.w_full()
|
|
.border_t_1()
|
|
.border_color(cx.theme().colors().border_variant)
|
|
.p_1()
|
|
.gap_4()
|
|
.justify_between()
|
|
.when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
|
|
this.child(match plan {
|
|
Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
|
|
.icon(IconName::ZedAssistant)
|
|
.icon_size(IconSize::Small)
|
|
.icon_color(Color::Muted)
|
|
.icon_position(IconPosition::Start)
|
|
.on_click(|_, window, cx| {
|
|
window
|
|
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
|
|
}),
|
|
Plan::ZedFree | Plan::ZedProTrial => Button::new(
|
|
"try-pro",
|
|
if plan == Plan::ZedProTrial {
|
|
"Upgrade to Pro"
|
|
} else {
|
|
"Try Pro"
|
|
},
|
|
)
|
|
.on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
|
|
})
|
|
})
|
|
.child(
|
|
Button::new("configure", "Configure")
|
|
.icon(IconName::Settings)
|
|
.icon_size(IconSize::Small)
|
|
.icon_color(Color::Muted)
|
|
.icon_position(IconPosition::Start)
|
|
.on_click(|_, window, cx| {
|
|
window.dispatch_action(
|
|
zed_actions::agent::OpenSettings.boxed_clone(),
|
|
cx,
|
|
);
|
|
}),
|
|
)
|
|
.into_any(),
|
|
)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` used to exercise matching and grouping; the
    /// completion-related methods are intentionally unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// The model's id equals its name, and the provider's id equals the
        /// provider's name.
        fn new(name: &str, provider: &str) -> Self {
            let model_name = name.to_string();
            let provider_label = provider.to_string();
            Self {
                name: LanguageModelName::from(model_name.clone()),
                id: LanguageModelId::from(model_name),
                provider_id: LanguageModelProviderId::from(provider_label.clone()),
                provider_name: LanguageModelProviderName::from(provider_label),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"<provider id>/<model name>"`; the assertions below compare against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Turns `(provider, model name)` pairs into `ModelInfo`s, preserving order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut models = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            models.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        models
    }

    /// Model list shared by the exact- and fuzzy-match tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Asserts that `result` holds exactly the models named by `expected`
    /// (as telemetry ids), in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        // Lengths are equal at this point, so zipping visits every position.
        for (position, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                position
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let mut actual_other_models = Vec::new();
        for provider_models in grouped_models.other.values() {
            actual_other_models.extend(provider_models.iter().cloned());
        }

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let mut actual_other_models = Vec::new();
        for provider_models in grouped_models.other.values() {
            actual_other_models.extend(provider_models.iter().cloned());
        }

        // Only the recommended provider/model pair is removed from "other";
        // a same-named model from a different provider must remain.
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}
|