agent: Fuzzy search in model selector (#30281)

This change enables fuzzy search over both the model provider and the model name. For
example, the query "z41" will match "zed/gpt-4.1".
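
The sketch below is only an illustration of the matching behavior, not the code in this commit: candidate strings are built as "provider/model", and a query matches when its characters appear in that string in order. The real implementation uses Zed's `fuzzy` crate with scored matches, as the diff shows; the function and names here are made up for the example.

// Illustrative standalone subsequence check; not the `fuzzy` crate's scorer.
fn is_fuzzy_match(query: &str, candidate: &str) -> bool {
    let mut chars = candidate.chars().map(|c| c.to_ascii_lowercase());
    query
        .chars()
        .map(|c| c.to_ascii_lowercase())
        .all(|q| chars.any(|c| c == q))
}

fn main() {
    // "z41" matches "zed/gpt-4.1" because 'z', '4', '1' appear in order.
    assert!(is_fuzzy_match("z41", "zed/gpt-4.1"));
    // There is no 'z' in "openai/gpt-4.1", so the same query does not match it.
    assert!(!is_fuzzy_match("z41", "openai/gpt-4.1"));
}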

Release Notes:

- Agent: Improved model selection with fuzzy search support
Oleksiy Syvokon 2025-05-09 14:36:29 +03:00 committed by GitHub
parent 2c602bb0e5
commit 023a60806a
3 changed files with 336 additions and 42 deletions


@@ -11,14 +11,26 @@ workspace = true
[lib]
path = "src/language_model_selector.rs"
[features]
test-support = [
"gpui/test-support",
]
[dependencies]
collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
language_model.workspace = true
log.workspace = true
ordered-float.workspace = true
picker.workspace = true
proto.workspace = true
ui.workspace = true
workspace-hack.workspace = true
zed_actions.workspace = true
[dev-dependencies]
gpui = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }


@@ -1,15 +1,18 @@
use std::sync::Arc;
use std::{cmp::Reverse, sync::Arc};
use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
action_with_deprecated_aliases,
};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
@@ -322,6 +325,23 @@ struct GroupedModels {
}
impl GroupedModels {
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in other {
let provider = model.model.provider_id();
if let Some(models) = other_by_provider.get_mut(&provider) {
models.push(model);
} else {
other_by_provider.insert(provider, vec![model]);
}
}
Self {
recommended,
other: other_by_provider,
}
}
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
let mut entries = Vec::new();
@@ -349,6 +369,20 @@ impl GroupedModels {
}
entries
}
fn model_infos(&self) -> Vec<ModelInfo> {
let other = self
.other
.values()
.flat_map(|model| model.iter())
.cloned()
.collect::<Vec<_>>();
self.recommended
.iter()
.chain(&other)
.cloned()
.collect::<Vec<_>>()
}
}
enum LanguageModelPickerEntry {
@@ -356,6 +390,78 @@ enum LanguageModelPickerEntry {
Separator(SharedString),
}
struct ModelMatcher {
models: Vec<ModelInfo>,
bg_executor: BackgroundExecutor,
candidates: Vec<StringMatchCandidate>,
}
impl ModelMatcher {
fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
let candidates = Self::make_match_candidates(&models);
Self {
models,
bg_executor,
candidates,
}
}
pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
let mut matches = self.bg_executor.block(match_strings(
&self.candidates,
&query,
false,
100,
&Default::default(),
self.bg_executor.clone(),
));
let sorting_key = |mat: &StringMatch| {
let candidate = &self.candidates[mat.candidate_id];
(Reverse(OrderedFloat(mat.score)), candidate.id)
};
matches.sort_unstable_by_key(sorting_key);
let matched_models: Vec<_> = matches
.into_iter()
.map(|mat| self.models[mat.candidate_id].clone())
.collect();
matched_models
}
pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
self.models
.iter()
.filter(|m| {
m.model
.name()
.0
.to_lowercase()
.contains(&query.to_lowercase())
})
.cloned()
.collect::<Vec<_>>()
}
fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
model_infos
.iter()
.enumerate()
.map(|(index, model)| {
StringMatchCandidate::new(
index,
&format!(
"{}/{}",
&model.model.provider_name().0,
&model.model.name().0
),
)
})
.collect::<Vec<_>>()
}
}
impl PickerDelegate for LanguageModelPickerDelegate {
type ListItem = AnyElement;
@@ -396,56 +502,45 @@ impl PickerDelegate for LanguageModelPickerDelegate {
) -> Task<()> {
let all_models = self.all_models.clone();
let current_index = self.selected_index;
let bg_executor = cx.background_executor();
let language_model_registry = LanguageModelRegistry::global(cx);
let configured_providers = language_model_registry
.read(cx)
.providers()
.iter()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
let configured_provider_ids = configured_providers
.iter()
.map(|provider| provider.id())
.collect::<Vec<_>>();
let recommended_models = all_models
.recommended
.iter()
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
.cloned()
.collect::<Vec<_>>();
let available_models = all_models
.model_infos()
.iter()
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
.cloned()
.collect::<Vec<_>>();
let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
let recommended = matcher_rec.exact_search(&query);
let all = matcher_all.fuzzy_search(&query);
let filtered_models = GroupedModels::new(all, recommended);
cx.spawn_in(window, async move |this, cx| {
let filtered_models = cx
.background_spawn(async move {
let matches = |info: &ModelInfo| {
info.model
.name()
.0
.to_lowercase()
.contains(&query.to_lowercase())
};
let recommended_models = all_models
.recommended
.iter()
.filter(|r| {
configured_providers.contains(&r.model.provider_id()) && matches(r)
})
.cloned()
.collect();
let mut other_models = IndexMap::default();
for (provider_id, models) in &all_models.other {
if configured_providers.contains(&provider_id) {
other_models.insert(
provider_id.clone(),
models
.iter()
.filter(|m| matches(m))
.cloned()
.collect::<Vec<_>>(),
);
}
}
GroupedModels {
recommended: recommended_models,
other: other_models,
}
})
.await;
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries = filtered_models.entries();
// Preserve selection focus
@@ -607,3 +702,187 @@ impl PickerDelegate for LanguageModelPickerDelegate {
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AsyncApp, TestAppContext, http_client};
use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRequest, LanguageModelToolChoice,
};
use ui::IconName;
#[derive(Clone)]
struct TestLanguageModel {
name: LanguageModelName,
id: LanguageModelId,
provider_id: LanguageModelProviderId,
provider_name: LanguageModelProviderName,
}
impl TestLanguageModel {
fn new(name: &str, provider: &str) -> Self {
Self {
name: LanguageModelName::from(name.to_string()),
id: LanguageModelId::from(name.to_string()),
provider_id: LanguageModelProviderId::from(provider.to_string()),
provider_name: LanguageModelProviderName::from(provider.to_string()),
}
}
}
impl LanguageModel for TestLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
self.name.clone()
}
fn provider_id(&self) -> LanguageModelProviderId {
self.provider_id.clone()
}
fn provider_name(&self) -> LanguageModelProviderName {
self.provider_name.clone()
}
fn supports_tools(&self) -> bool {
false
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
false
}
fn telemetry_id(&self) -> String {
format!("{}/{}", self.provider_id.0, self.name.0)
}
fn max_token_count(&self) -> usize {
1000
}
fn count_tokens(
&self,
_: LanguageModelRequest,
_: &App,
) -> BoxFuture<'static, http_client::Result<usize>> {
unimplemented!()
}
fn stream_completion(
&self,
_: LanguageModelRequest,
_: &AsyncApp,
) -> BoxFuture<
'static,
http_client::Result<
BoxStream<
'static,
http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
unimplemented!()
}
}
fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
model_specs
.into_iter()
.map(|(provider, name)| ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
})
.collect()
}
fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
assert_eq!(
result.len(),
expected.len(),
"Number of models doesn't match"
);
for (i, expected_name) in expected.iter().enumerate() {
assert_eq!(
result[i].model.telemetry_id(),
*expected_name,
"Model at position {} doesn't match expected model",
i
);
}
}
#[gpui::test]
fn test_exact_match(cx: &mut TestAppContext) {
let models = create_models(vec![
("zed", "Claude 3.7 Sonnet"),
("zed", "Claude 3.7 Sonnet Thinking"),
("zed", "gpt-4.1"),
("zed", "gpt-4.1-nano"),
("openai", "gpt-3.5-turbo"),
("openai", "gpt-4.1"),
("openai", "gpt-4.1-nano"),
("ollama", "mistral"),
("ollama", "deepseek"),
]);
let matcher = ModelMatcher::new(models, cx.background_executor.clone());
// The order of models should be maintained; case doesn't matter
let results = matcher.exact_search("GPT-4.1");
assert_models_eq(
results,
vec![
"zed/gpt-4.1",
"zed/gpt-4.1-nano",
"openai/gpt-4.1",
"openai/gpt-4.1-nano",
],
);
}
#[gpui::test]
fn test_fuzzy_match(cx: &mut TestAppContext) {
let models = create_models(vec![
("zed", "Claude 3.7 Sonnet"),
("zed", "Claude 3.7 Sonnet Thinking"),
("zed", "gpt-4.1"),
("zed", "gpt-4.1-nano"),
("openai", "gpt-3.5-turbo"),
("openai", "gpt-4.1"),
("openai", "gpt-4.1-nano"),
("ollama", "mistral"),
("ollama", "deepseek"),
]);
let matcher = ModelMatcher::new(models, cx.background_executor.clone());
// Results should preserve the original model order whenever possible.
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
// so it should appear first in the results.
let results = matcher.fuzzy_search("41");
assert_models_eq(
results,
vec![
"zed/gpt-4.1",
"openai/gpt-4.1",
"zed/gpt-4.1-nano",
"openai/gpt-4.1-nano",
],
);
// Model provider should be searchable as well
let results = matcher.fuzzy_search("ol"); // meaning "ollama"
assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
// Fuzzy search
let results = matcher.fuzzy_search("z4n");
assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
}
}