agent: Fuzzy search in model selector (#30281)
This change enables fuzzy search on model providers and names. For example, the query "z41" will match "zed/gpt-4.1". Release Notes: - Agent: Improved model selection with fuzzy search support
This commit is contained in:
parent
2c602bb0e5
commit
023a60806a
3 changed files with 336 additions and 42 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -7813,9 +7813,12 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"collections",
|
"collections",
|
||||||
"feature_flags",
|
"feature_flags",
|
||||||
|
"futures 0.3.31",
|
||||||
|
"fuzzy",
|
||||||
"gpui",
|
"gpui",
|
||||||
"language_model",
|
"language_model",
|
||||||
"log",
|
"log",
|
||||||
|
"ordered-float 2.10.1",
|
||||||
"picker",
|
"picker",
|
||||||
"proto",
|
"proto",
|
||||||
"ui",
|
"ui",
|
||||||
|
|
|
@ -11,14 +11,26 @@ workspace = true
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/language_model_selector.rs"
|
path = "src/language_model_selector.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = [
|
||||||
|
"gpui/test-support",
|
||||||
|
]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
feature_flags.workspace = true
|
feature_flags.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
fuzzy.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
language_model.workspace = true
|
language_model.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
|
ordered-float.workspace = true
|
||||||
picker.workspace = true
|
picker.workspace = true
|
||||||
proto.workspace = true
|
proto.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
zed_actions.workspace = true
|
zed_actions.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
gpui = { workspace = true, "features" = ["test-support"] }
|
||||||
|
language_model = { workspace = true, "features" = ["test-support"] }
|
||||||
|
|
|
@ -1,15 +1,18 @@
|
||||||
use std::sync::Arc;
|
use std::{cmp::Reverse, sync::Arc};
|
||||||
|
|
||||||
use collections::{HashSet, IndexMap};
|
use collections::{HashSet, IndexMap};
|
||||||
use feature_flags::ZedProFeatureFlag;
|
use feature_flags::ZedProFeatureFlag;
|
||||||
|
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
|
Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
|
||||||
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
|
EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
|
||||||
|
action_with_deprecated_aliases,
|
||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
|
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
|
||||||
LanguageModelRegistry,
|
LanguageModelRegistry,
|
||||||
};
|
};
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use picker::{Picker, PickerDelegate};
|
use picker::{Picker, PickerDelegate};
|
||||||
use proto::Plan;
|
use proto::Plan;
|
||||||
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
|
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
|
||||||
|
@ -322,6 +325,23 @@ struct GroupedModels {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GroupedModels {
|
impl GroupedModels {
|
||||||
|
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
||||||
|
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
||||||
|
for model in other {
|
||||||
|
let provider = model.model.provider_id();
|
||||||
|
if let Some(models) = other_by_provider.get_mut(&provider) {
|
||||||
|
models.push(model);
|
||||||
|
} else {
|
||||||
|
other_by_provider.insert(provider, vec![model]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
recommended,
|
||||||
|
other: other_by_provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
|
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
|
||||||
let mut entries = Vec::new();
|
let mut entries = Vec::new();
|
||||||
|
|
||||||
|
@ -349,6 +369,20 @@ impl GroupedModels {
|
||||||
}
|
}
|
||||||
entries
|
entries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn model_infos(&self) -> Vec<ModelInfo> {
|
||||||
|
let other = self
|
||||||
|
.other
|
||||||
|
.values()
|
||||||
|
.flat_map(|model| model.iter())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
self.recommended
|
||||||
|
.iter()
|
||||||
|
.chain(&other)
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enum LanguageModelPickerEntry {
|
enum LanguageModelPickerEntry {
|
||||||
|
@ -356,6 +390,78 @@ enum LanguageModelPickerEntry {
|
||||||
Separator(SharedString),
|
Separator(SharedString),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ModelMatcher {
|
||||||
|
models: Vec<ModelInfo>,
|
||||||
|
bg_executor: BackgroundExecutor,
|
||||||
|
candidates: Vec<StringMatchCandidate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelMatcher {
|
||||||
|
fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
|
||||||
|
let candidates = Self::make_match_candidates(&models);
|
||||||
|
Self {
|
||||||
|
models,
|
||||||
|
bg_executor,
|
||||||
|
candidates,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
|
||||||
|
let mut matches = self.bg_executor.block(match_strings(
|
||||||
|
&self.candidates,
|
||||||
|
&query,
|
||||||
|
false,
|
||||||
|
100,
|
||||||
|
&Default::default(),
|
||||||
|
self.bg_executor.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let sorting_key = |mat: &StringMatch| {
|
||||||
|
let candidate = &self.candidates[mat.candidate_id];
|
||||||
|
(Reverse(OrderedFloat(mat.score)), candidate.id)
|
||||||
|
};
|
||||||
|
matches.sort_unstable_by_key(sorting_key);
|
||||||
|
|
||||||
|
let matched_models: Vec<_> = matches
|
||||||
|
.into_iter()
|
||||||
|
.map(|mat| self.models[mat.candidate_id].clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
matched_models
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
|
||||||
|
self.models
|
||||||
|
.iter()
|
||||||
|
.filter(|m| {
|
||||||
|
m.model
|
||||||
|
.name()
|
||||||
|
.0
|
||||||
|
.to_lowercase()
|
||||||
|
.contains(&query.to_lowercase())
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
|
||||||
|
model_infos
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(index, model)| {
|
||||||
|
StringMatchCandidate::new(
|
||||||
|
index,
|
||||||
|
&format!(
|
||||||
|
"{}/{}",
|
||||||
|
&model.model.provider_name().0,
|
||||||
|
&model.model.name().0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl PickerDelegate for LanguageModelPickerDelegate {
|
impl PickerDelegate for LanguageModelPickerDelegate {
|
||||||
type ListItem = AnyElement;
|
type ListItem = AnyElement;
|
||||||
|
|
||||||
|
@ -396,56 +502,45 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||||
) -> Task<()> {
|
) -> Task<()> {
|
||||||
let all_models = self.all_models.clone();
|
let all_models = self.all_models.clone();
|
||||||
let current_index = self.selected_index;
|
let current_index = self.selected_index;
|
||||||
|
let bg_executor = cx.background_executor();
|
||||||
|
|
||||||
let language_model_registry = LanguageModelRegistry::global(cx);
|
let language_model_registry = LanguageModelRegistry::global(cx);
|
||||||
|
|
||||||
let configured_providers = language_model_registry
|
let configured_providers = language_model_registry
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.providers()
|
.providers()
|
||||||
.iter()
|
.into_iter()
|
||||||
.filter(|provider| provider.is_authenticated(cx))
|
.filter(|provider| provider.is_authenticated(cx))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let configured_provider_ids = configured_providers
|
||||||
|
.iter()
|
||||||
.map(|provider| provider.id())
|
.map(|provider| provider.id())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let recommended_models = all_models
|
||||||
|
.recommended
|
||||||
|
.iter()
|
||||||
|
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let available_models = all_models
|
||||||
|
.model_infos()
|
||||||
|
.iter()
|
||||||
|
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
|
||||||
|
let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
|
||||||
|
|
||||||
|
let recommended = matcher_rec.exact_search(&query);
|
||||||
|
let all = matcher_all.fuzzy_search(&query);
|
||||||
|
|
||||||
|
let filtered_models = GroupedModels::new(all, recommended);
|
||||||
|
|
||||||
cx.spawn_in(window, async move |this, cx| {
|
cx.spawn_in(window, async move |this, cx| {
|
||||||
let filtered_models = cx
|
|
||||||
.background_spawn(async move {
|
|
||||||
let matches = |info: &ModelInfo| {
|
|
||||||
info.model
|
|
||||||
.name()
|
|
||||||
.0
|
|
||||||
.to_lowercase()
|
|
||||||
.contains(&query.to_lowercase())
|
|
||||||
};
|
|
||||||
|
|
||||||
let recommended_models = all_models
|
|
||||||
.recommended
|
|
||||||
.iter()
|
|
||||||
.filter(|r| {
|
|
||||||
configured_providers.contains(&r.model.provider_id()) && matches(r)
|
|
||||||
})
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
let mut other_models = IndexMap::default();
|
|
||||||
for (provider_id, models) in &all_models.other {
|
|
||||||
if configured_providers.contains(&provider_id) {
|
|
||||||
other_models.insert(
|
|
||||||
provider_id.clone(),
|
|
||||||
models
|
|
||||||
.iter()
|
|
||||||
.filter(|m| matches(m))
|
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GroupedModels {
|
|
||||||
recommended: recommended_models,
|
|
||||||
other: other_models,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
this.update_in(cx, |this, window, cx| {
|
this.update_in(cx, |this, window, cx| {
|
||||||
this.delegate.filtered_entries = filtered_models.entries();
|
this.delegate.filtered_entries = filtered_models.entries();
|
||||||
// Preserve selection focus
|
// Preserve selection focus
|
||||||
|
@ -607,3 +702,187 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
|
use gpui::{AsyncApp, TestAppContext, http_client};
|
||||||
|
use language_model::{
|
||||||
|
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
|
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
|
LanguageModelRequest, LanguageModelToolChoice,
|
||||||
|
};
|
||||||
|
use ui::IconName;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct TestLanguageModel {
|
||||||
|
name: LanguageModelName,
|
||||||
|
id: LanguageModelId,
|
||||||
|
provider_id: LanguageModelProviderId,
|
||||||
|
provider_name: LanguageModelProviderName,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestLanguageModel {
|
||||||
|
fn new(name: &str, provider: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
name: LanguageModelName::from(name.to_string()),
|
||||||
|
id: LanguageModelId::from(name.to_string()),
|
||||||
|
provider_id: LanguageModelProviderId::from(provider.to_string()),
|
||||||
|
provider_name: LanguageModelProviderName::from(provider.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for TestLanguageModel {
|
||||||
|
fn id(&self) -> LanguageModelId {
|
||||||
|
self.id.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> LanguageModelName {
|
||||||
|
self.name.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_id(&self) -> LanguageModelProviderId {
|
||||||
|
self.provider_id.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn provider_name(&self) -> LanguageModelProviderName {
|
||||||
|
self.provider_name.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_tools(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn telemetry_id(&self) -> String {
|
||||||
|
format!("{}/{}", self.provider_id.0, self.name.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_token_count(&self) -> usize {
|
||||||
|
1000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn count_tokens(
|
||||||
|
&self,
|
||||||
|
_: LanguageModelRequest,
|
||||||
|
_: &App,
|
||||||
|
) -> BoxFuture<'static, http_client::Result<usize>> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_completion(
|
||||||
|
&self,
|
||||||
|
_: LanguageModelRequest,
|
||||||
|
_: &AsyncApp,
|
||||||
|
) -> BoxFuture<
|
||||||
|
'static,
|
||||||
|
http_client::Result<
|
||||||
|
BoxStream<
|
||||||
|
'static,
|
||||||
|
http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
|
> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
|
||||||
|
model_specs
|
||||||
|
.into_iter()
|
||||||
|
.map(|(provider, name)| ModelInfo {
|
||||||
|
model: Arc::new(TestLanguageModel::new(name, provider)),
|
||||||
|
icon: IconName::Ai,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
|
||||||
|
assert_eq!(
|
||||||
|
result.len(),
|
||||||
|
expected.len(),
|
||||||
|
"Number of models doesn't match"
|
||||||
|
);
|
||||||
|
|
||||||
|
for (i, expected_name) in expected.iter().enumerate() {
|
||||||
|
assert_eq!(
|
||||||
|
result[i].model.telemetry_id(),
|
||||||
|
*expected_name,
|
||||||
|
"Model at position {} doesn't match expected model",
|
||||||
|
i
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_exact_match(cx: &mut TestAppContext) {
|
||||||
|
let models = create_models(vec![
|
||||||
|
("zed", "Claude 3.7 Sonnet"),
|
||||||
|
("zed", "Claude 3.7 Sonnet Thinking"),
|
||||||
|
("zed", "gpt-4.1"),
|
||||||
|
("zed", "gpt-4.1-nano"),
|
||||||
|
("openai", "gpt-3.5-turbo"),
|
||||||
|
("openai", "gpt-4.1"),
|
||||||
|
("openai", "gpt-4.1-nano"),
|
||||||
|
("ollama", "mistral"),
|
||||||
|
("ollama", "deepseek"),
|
||||||
|
]);
|
||||||
|
let matcher = ModelMatcher::new(models, cx.background_executor.clone());
|
||||||
|
|
||||||
|
// The order of models should be maintained, case doesn't matter
|
||||||
|
let results = matcher.exact_search("GPT-4.1");
|
||||||
|
assert_models_eq(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
"zed/gpt-4.1",
|
||||||
|
"zed/gpt-4.1-nano",
|
||||||
|
"openai/gpt-4.1",
|
||||||
|
"openai/gpt-4.1-nano",
|
||||||
|
],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_fuzzy_match(cx: &mut TestAppContext) {
|
||||||
|
let models = create_models(vec![
|
||||||
|
("zed", "Claude 3.7 Sonnet"),
|
||||||
|
("zed", "Claude 3.7 Sonnet Thinking"),
|
||||||
|
("zed", "gpt-4.1"),
|
||||||
|
("zed", "gpt-4.1-nano"),
|
||||||
|
("openai", "gpt-3.5-turbo"),
|
||||||
|
("openai", "gpt-4.1"),
|
||||||
|
("openai", "gpt-4.1-nano"),
|
||||||
|
("ollama", "mistral"),
|
||||||
|
("ollama", "deepseek"),
|
||||||
|
]);
|
||||||
|
let matcher = ModelMatcher::new(models, cx.background_executor.clone());
|
||||||
|
|
||||||
|
// Results should preserve models order whenever possible.
|
||||||
|
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
|
||||||
|
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
|
||||||
|
// so it should appear first in the results.
|
||||||
|
let results = matcher.fuzzy_search("41");
|
||||||
|
assert_models_eq(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
"zed/gpt-4.1",
|
||||||
|
"openai/gpt-4.1",
|
||||||
|
"zed/gpt-4.1-nano",
|
||||||
|
"openai/gpt-4.1-nano",
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Model provider should be searchable as well
|
||||||
|
let results = matcher.fuzzy_search("ol"); // meaning "ollama"
|
||||||
|
assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
|
||||||
|
|
||||||
|
// Fuzzy search
|
||||||
|
let results = matcher.fuzzy_search("z4n");
|
||||||
|
assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue