From db497ac867ce8c9a2bad0aef6261ac2acb2896fa Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 13 Aug 2025 11:01:02 +0200 Subject: [PATCH] Agent2 Model Selector (#36028) Release Notes: - N/A --------- Co-authored-by: Bennet Bo Fenner --- Cargo.lock | 3 +- crates/acp_thread/Cargo.toml | 3 +- crates/acp_thread/src/acp_thread.rs | 4 + crates/acp_thread/src/connection.rs | 151 ++++-- crates/agent2/src/agent.rs | 395 ++++++++++++--- crates/agent2/src/native_agent_server.rs | 17 +- crates/agent2/src/tests/mod.rs | 42 +- crates/agent_ui/src/acp.rs | 4 + crates/agent_ui/src/acp/model_selector.rs | 472 ++++++++++++++++++ .../src/acp/model_selector_popover.rs | 85 ++++ crates/agent_ui/src/acp/thread_view.rs | 41 +- crates/agent_ui/src/agent_panel.rs | 5 +- crates/agent_ui/src/agent_ui.rs | 4 +- 13 files changed, 1078 insertions(+), 148 deletions(-) create mode 100644 crates/agent_ui/src/acp/model_selector.rs create mode 100644 crates/agent_ui/src/acp/model_selector_popover.rs diff --git a/Cargo.lock b/Cargo.lock index ffcaf64859..d31189fa06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,6 +10,7 @@ dependencies = [ "agent-client-protocol", "anyhow", "buffer_diff", + "collections", "editor", "env_logger 0.11.8", "futures 0.3.31", @@ -17,7 +18,6 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", - "language_model", "markdown", "parking_lot", "project", @@ -31,6 +31,7 @@ dependencies = [ "ui", "url", "util", + "watch", "workspace-hack", ] diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 1fef342c01..fd01b31786 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -20,12 +20,12 @@ action_log.workspace = true agent-client-protocol.workspace = true anyhow.workspace = true buffer_diff.workspace = true +collections.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true -language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true @@ -36,6 +36,7 @@ terminal.workspace = true ui.workspace = true url.workspace = true util.workspace = true +watch.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 80e0a31f97..d1957e1c2a 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -694,6 +694,10 @@ impl AcpThread { } } + pub fn connection(&self) -> &Rc { + &self.connection + } + pub fn action_log(&self) -> &Entity { &self.action_log } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index cf06563bee..8e6294b3ce 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,61 +1,14 @@ -use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use std::{error::Error, fmt, path::Path, rc::Rc}; use agent_client_protocol::{self as acp}; use anyhow::Result; -use gpui::{AsyncApp, Entity, Task}; -use language_model::LanguageModel; +use collections::IndexMap; +use gpui::{AsyncApp, Entity, SharedString, Task}; use project::Project; -use ui::App; +use ui::{App, IconName}; use crate::AcpThread; -/// Trait for agents that support listing, selecting, and querying language models. -/// -/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. -pub trait ModelSelector: 'static { - /// Lists all available language models for this agent. - /// - /// # Parameters - /// - `cx`: The GPUI app context for async operations and global access. - /// - /// # Returns - /// A task resolving to the list of models or an error (e.g., if no models are configured). - fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; - - /// Selects a model for a specific session (thread). - /// - /// This sets the default model for future interactions in the session. - /// If the session doesn't exist or the model is invalid, it returns an error. - /// - /// # Parameters - /// - `session_id`: The ID of the session (thread) to apply the model to. - /// - `model`: The model to select (should be one from [list_models]). - /// - `cx`: The GPUI app context. - /// - /// # Returns - /// A task resolving to `Ok(())` on success or an error. - fn select_model( - &self, - session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, - ) -> Task>; - - /// Retrieves the currently selected model for a specific session (thread). - /// - /// # Parameters - /// - `session_id`: The ID of the session (thread) to query. - /// - `cx`: The GPUI app context. - /// - /// # Returns - /// A task resolving to the selected model (always set) or an error (e.g., session not found). - fn selected_model( - &self, - session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>>; -} - pub trait AgentConnection { fn new_thread( self: Rc, @@ -77,8 +30,8 @@ pub trait AgentConnection { /// /// If the agent does not support model selection, returns [None]. /// This allows sharing the selector in UI components. - fn model_selector(&self) -> Option> { - None // Default impl for agents that don't support it + fn model_selector(&self) -> Option> { + None } } @@ -91,3 +44,95 @@ impl fmt::Display for AuthRequired { write!(f, "AuthRequired") } } + +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait AgentModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut App) -> Task>; + + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: acp::SessionId, + model_id: AgentModelId, + cx: &mut App, + ) -> Task>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut App, + ) -> Task>; + + /// Whenever the model list is updated the receiver will be notified. + fn watch(&self, cx: &mut App) -> watch::Receiver<()>; +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelId(pub SharedString); + +impl std::ops::Deref for AgentModelId { + type Target = SharedString; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl fmt::Display for AgentModelId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AgentModelInfo { + pub id: AgentModelId, + pub name: SharedString, + pub icon: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AgentModelGroupName(pub SharedString); + +#[derive(Debug, Clone)] +pub enum AgentModelList { + Flat(Vec), + Grouped(IndexMap>), +} + +impl AgentModelList { + pub fn is_empty(&self) -> bool { + match self { + AgentModelList::Flat(models) => models.is_empty(), + AgentModelList::Grouped(groups) => groups.is_empty(), + } + } +} diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 7439b2a088..3ddd7be793 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -4,18 +4,22 @@ use crate::{ FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, }; -use acp_thread::ModelSelector; +use acp_thread::AgentModelSelector; use agent_client_protocol as acp; +use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; +use collections::{HashSet, IndexMap}; +use fs::Fs; use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; +use settings::update_settings_file; use std::cell::RefCell; use std::collections::HashMap; use std::path::Path; @@ -48,6 +52,104 @@ struct Session { _subscription: Subscription, } +pub struct LanguageModels { + /// Access language model by ID + models: HashMap>, + /// Cached list for returning language model information + model_list: acp_thread::AgentModelList, + refresh_models_rx: watch::Receiver<()>, + refresh_models_tx: watch::Sender<()>, +} + +impl LanguageModels { + fn new(cx: &App) -> Self { + let (refresh_models_tx, refresh_models_rx) = watch::channel(()); + let mut this = Self { + models: HashMap::default(), + model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), + refresh_models_rx, + refresh_models_tx, + }; + this.refresh_list(cx); + this + } + + fn refresh_list(&mut self, cx: &App) { + let providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .into_iter() + .filter(|provider| provider.is_authenticated(cx)) + .collect::>(); + + let mut language_model_list = IndexMap::default(); + let mut recommended_models = HashSet::default(); + + let mut recommended = Vec::new(); + for provider in &providers { + for model in provider.recommended_models(cx) { + recommended_models.insert(model.id()); + recommended.push(Self::map_language_model_to_info(&model, &provider)); + } + } + if !recommended.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName("Recommended".into()), + recommended, + ); + } + + let mut models = HashMap::default(); + for provider in providers { + let mut provider_models = Vec::new(); + for model in provider.provided_models(cx) { + let model_info = Self::map_language_model_to_info(&model, &provider); + let model_id = model_info.id.clone(); + if !recommended_models.contains(&model.id()) { + provider_models.push(model_info); + } + models.insert(model_id, model); + } + if !provider_models.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName(provider.name().0.clone()), + provider_models, + ); + } + } + + self.models = models; + self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); + self.refresh_models_tx.send(()).ok(); + } + + fn watch(&self) -> watch::Receiver<()> { + self.refresh_models_rx.clone() + } + + pub fn model_from_id( + &self, + model_id: &acp_thread::AgentModelId, + ) -> Option> { + self.models.get(model_id).cloned() + } + + fn map_language_model_to_info( + model: &Arc, + provider: &Arc, + ) -> acp_thread::AgentModelInfo { + acp_thread::AgentModelInfo { + id: Self::model_id(model), + name: model.name().0, + icon: Some(provider.icon()), + } + } + + fn model_id(model: &Arc) -> acp_thread::AgentModelId { + acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) + } +} + pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, @@ -58,8 +160,11 @@ pub struct NativeAgent { context_server_registry: Entity, /// Shared templates for all threads templates: Arc, + /// Cached model information + models: LanguageModels, project: Entity, prompt_store: Option>, + fs: Arc, _subscriptions: Vec, } @@ -68,6 +173,7 @@ impl NativeAgent { project: Entity, templates: Arc, prompt_store: Option>, + fs: Arc, cx: &mut AsyncApp, ) -> Result> { log::info!("Creating new NativeAgent"); @@ -77,7 +183,13 @@ impl NativeAgent { .await; cx.new(|cx| { - let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + let mut subscriptions = vec![ + cx.subscribe(&project, Self::handle_project_event), + cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_models_updated_event, + ), + ]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) } @@ -95,13 +207,19 @@ impl NativeAgent { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), templates, + models: LanguageModels::new(cx), project, prompt_store, + fs, _subscriptions: subscriptions, } }) } + pub fn models(&self) -> &LanguageModels { + &self.models + } + async fn maintain_project_context( this: WeakEntity, mut needs_refresh: watch::Receiver<()>, @@ -297,75 +415,104 @@ impl NativeAgent { ) { self.project_context_needs_refresh.send(()).ok(); } + + fn handle_models_updated_event( + &mut self, + _registry: Entity, + _event: &language_model::Event, + cx: &mut Context, + ) { + self.models.refresh_list(cx); + for session in self.sessions.values_mut() { + session.thread.update(cx, |thread, _| { + let model_id = LanguageModels::model_id(&thread.selected_model); + if let Some(model) = self.models.model_from_id(&model_id) { + thread.selected_model = model.clone(); + } + }); + } + } } /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] pub struct NativeAgentConnection(pub Entity); -impl ModelSelector for NativeAgentConnection { - fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { +impl AgentModelSelector for NativeAgentConnection { + fn list_models(&self, cx: &mut App) -> Task> { log::debug!("NativeAgentConnection::list_models called"); - cx.spawn(async move |cx| { - cx.update(|cx| { - let registry = LanguageModelRegistry::read_global(cx); - let models = registry.available_models(cx).collect::>(); - log::info!("Found {} available models", models.len()); - if models.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(models) - } - })? + let list = self.0.read(cx).models.model_list.clone(); + Task::ready(if list.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(list) }) } fn select_model( &self, session_id: acp::SessionId, - model: Arc, - cx: &mut AsyncApp, + model_id: acp_thread::AgentModelId, + cx: &mut App, ) -> Task> { - log::info!( - "Setting model for session {}: {:?}", - session_id, - model.name() - ); - let agent = self.0.clone(); + log::info!("Setting model for session {}: {}", session_id, model_id); + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; - cx.spawn(async move |cx| { - agent.update(cx, |agent, cx| { - if let Some(session) = agent.sessions.get(&session_id) { - session.thread.update(cx, |thread, _cx| { - thread.selected_model = model; - }); - Ok(()) - } else { - Err(anyhow!("Session not found")) - } - })? - }) + let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else { + return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); + }; + + thread.update(cx, |thread, _cx| { + thread.selected_model = model.clone(); + }); + + update_settings_file::( + self.0.read(cx).fs.clone(), + cx, + move |settings, _cx| { + settings.set_model(model); + }, + ); + + Task::ready(Ok(())) } fn selected_model( &self, session_id: &acp::SessionId, - cx: &mut AsyncApp, - ) -> Task>> { - let agent = self.0.clone(); + cx: &mut App, + ) -> Task> { let session_id = session_id.clone(); - cx.spawn(async move |cx| { - let thread = agent - .read_with(cx, |agent, _| { - agent - .sessions - .get(&session_id) - .map(|session| session.thread.clone()) - })? - .ok_or_else(|| anyhow::anyhow!("Session not found"))?; - let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; - Ok(selected) - }) + + let Some(thread) = self + .0 + .read(cx) + .sessions + .get(&session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + let model = thread.read(cx).selected_model.clone(); + let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) + else { + return Task::ready(Err(anyhow!("Provider not found"))); + }; + Task::ready(Ok(LanguageModels::map_language_model_to_info( + &model, &provider, + ))) + } + + fn watch(&self, cx: &mut App) -> watch::Receiver<()> { + self.0.read(cx).models.watch() } } @@ -413,13 +560,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let default_model = registry .default_model() - .map(|configured| { - log::info!( - "Using configured default model: {:?} from provider: {:?}", - configured.model.name(), - configured.provider.name() - ); - configured.model + .and_then(|default_model| { + agent + .models + .model_from_id(&LanguageModels::model_id(&default_model.model)) }) .ok_or_else(|| { log::warn!("No default model configured in settings"); @@ -487,8 +631,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn model_selector(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) } fn prompt( @@ -629,6 +773,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { #[cfg(test)] mod tests { use super::*; + use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; @@ -646,9 +791,15 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [], cx).await; - let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); agent.read_with(cx, |agent, _| { assert_eq!(agent.project_context.borrow().worktrees, vec![]) }); @@ -689,13 +840,131 @@ mod tests { }); } + #[gpui::test] + async fn test_listing_models(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({ "a": {} })).await; + let project = Project::test(fs.clone(), [], cx).await; + let connection = NativeAgentConnection( + NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(), + ); + + let models = cx.update(|cx| connection.list_models(cx)).await.unwrap(); + + let acp_thread::AgentModelList::Grouped(models) = models else { + panic!("Unexpected model group"); + }; + assert_eq!( + models, + IndexMap::from_iter([( + AgentModelGroupName("Fake".into()), + vec![AgentModelInfo { + id: AgentModelId("fake/fake".into()), + name: "Fake".into(), + icon: Some(ui::IconName::ZedAssistant), + }] + )]) + ); + } + + #[gpui::test] + async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_model": { + "provider": "foo", + "model": "bar" + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + + // Create the agent and connection + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread( + project.clone(), + Path::new("/a"), + &mut cx.to_async(), + ) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + // Select a model + let model_id = AgentModelId("fake/fake".into()); + cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx)) + .await + .unwrap(); + + // Verify the thread has the selected model + agent.read_with(cx, |agent, _| { + let session = agent.sessions.get(&session_id).unwrap(); + session.thread.read_with(cx, |thread, _| { + assert_eq!(thread.selected_model.id().0, "fake"); + }); + }); + + cx.run_until_parked(); + + // Verify settings file was updated + let settings_content = fs.load(paths::settings_file()).await.unwrap(); + let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); + + // Check that the agent settings contain the selected model + assert_eq!( + settings_json["agent"]["default_model"]["model"], + json!("fake") + ); + assert_eq!( + settings_json["agent"]["default_model"]["provider"], + json!("fake") + ); + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); Project::init_settings(cx); + agent_settings::init(cx); language::init(cx); + LanguageModelRegistry::test(cx); }); } } diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index 58f6d37c54..cadd88a846 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,8 +1,8 @@ -use std::path::Path; -use std::rc::Rc; +use std::{path::Path, rc::Rc, sync::Arc}; use agent_servers::AgentServer; use anyhow::Result; +use fs::Fs; use gpui::{App, Entity, Task}; use project::Project; use prompt_store::PromptStore; @@ -10,7 +10,15 @@ use prompt_store::PromptStore; use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; #[derive(Clone)] -pub struct NativeAgentServer; +pub struct NativeAgentServer { + fs: Arc, +} + +impl NativeAgentServer { + pub fn new(fs: Arc) -> Self { + Self { fs } + } +} impl AgentServer for NativeAgentServer { fn name(&self) -> &'static str { @@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer { _root_dir ); let project = project.clone(); + let fs = self.fs.clone(); let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { log::debug!("Creating templates for native agent"); @@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer { let prompt_store = prompt_store.await?; log::debug!("Creating native agent entity"); - let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; + let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?; // Create the connection wrapper let connection = NativeAgentConnection(agent); diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 88cf92836b..b70fa56747 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,6 +1,6 @@ use super::*; use crate::MessageContent; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList}; use action_log::ActionLog; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; @@ -686,13 +686,19 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Create a project for new_thread let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); fake_fs.insert_tree(path!("/test"), json!({})).await; - let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await; let cwd = Path::new("/test"); // Create agent and connection - let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) - .await - .unwrap(); + let agent = NativeAgent::new( + project.clone(), + templates.clone(), + None, + fake_fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); let connection = NativeAgentConnection(agent.clone()); // Test model_selector returns Some @@ -705,22 +711,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test list_models let listed_models = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.list_models(&mut async_cx) - }) + .update(|cx| selector.list_models(cx)) .await .expect("list_models should succeed"); + let AgentModelList::Grouped(listed_models) = listed_models else { + panic!("Unexpected model list type"); + }; assert!(!listed_models.is_empty(), "should have at least one model"); - assert_eq!(listed_models[0].id().0, "fake"); + assert_eq!( + listed_models[&AgentModelGroupName("Fake".into())][0].id.0, + "fake/fake" + ); // Create a thread using new_thread let connection_rc = Rc::new(connection.clone()); let acp_thread = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - connection_rc.new_thread(project, cwd, &mut async_cx) - }) + .update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async())) .await .expect("new_thread should succeed"); @@ -729,12 +735,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) { // Test selected_model returns the default let model = cx - .update(|cx| { - let mut async_cx = cx.to_async(); - selector.selected_model(&session_id, &mut async_cx) - }) + .update(|cx| selector.selected_model(&session_id, cx)) .await .expect("selected_model should succeed"); + let model = cx + .update(|cx| agent.read(cx).models().model_from_id(&model.id)) + .unwrap(); let model = model.as_fake(); assert_eq!(model.id().0, "fake", "should return default model"); diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index cc476b1a86..b9814adb2d 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -1,6 +1,10 @@ mod completion_provider; mod message_history; +mod model_selector; +mod model_selector_popover; mod thread_view; pub use message_history::MessageHistory; +pub use model_selector::AcpModelSelector; +pub use model_selector_popover::AcpModelSelectorPopover; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs new file mode 100644 index 0000000000..563afee65f --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -0,0 +1,472 @@ +use std::{cmp::Reverse, rc::Rc, sync::Arc}; + +use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use agent_client_protocol as acp; +use anyhow::Result; +use collections::IndexMap; +use futures::FutureExt; +use fuzzy::{StringMatchCandidate, match_strings}; +use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity}; +use ordered_float::OrderedFloat; +use picker::{Picker, PickerDelegate}; +use ui::{ + AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window, + prelude::*, rems, +}; +use util::ResultExt; + +pub type AcpModelSelector = Picker; + +pub fn acp_model_selector( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, +) -> AcpModelSelector { + let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx); + Picker::list(delegate, window, cx) + .show_scrollbar(true) + .width(rems(20.)) + .max_height(Some(rems(20.).into())) +} + +enum AcpModelPickerEntry { + Separator(SharedString), + Model(AgentModelInfo), +} + +pub struct AcpModelPickerDelegate { + session_id: acp::SessionId, + selector: Rc, + filtered_entries: Vec, + models: Option, + selected_index: usize, + selected_model: Option, + _refresh_models_task: Task<()>, +} + +impl AcpModelPickerDelegate { + fn new( + session_id: acp::SessionId, + selector: Rc, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let mut rx = selector.watch(cx); + let refresh_models_task = cx.spawn_in(window, { + let session_id = session_id.clone(); + async move |this, cx| { + async fn refresh( + this: &WeakEntity>, + session_id: &acp::SessionId, + cx: &mut AsyncWindowContext, + ) -> Result<()> { + let (models_task, selected_model_task) = this.update(cx, |this, cx| { + ( + this.delegate.selector.list_models(cx), + this.delegate.selector.selected_model(session_id, cx), + ) + })?; + + let (models, selected_model) = futures::join!(models_task, selected_model_task); + + this.update_in(cx, |this, window, cx| { + this.delegate.models = models.ok(); + this.delegate.selected_model = selected_model.ok(); + this.delegate.update_matches(this.query(cx), window, cx) + })? + .await; + + Ok(()) + } + + refresh(&this, &session_id, cx).await.log_err(); + while let Ok(()) = rx.recv().await { + refresh(&this, &session_id, cx).await.log_err(); + } + } + }); + + Self { + session_id, + selector, + filtered_entries: Vec::new(), + models: None, + selected_model: None, + selected_index: 0, + _refresh_models_task: refresh_models_task, + } + } + + pub fn active_model(&self) -> Option<&AgentModelInfo> { + self.selected_model.as_ref() + } +} + +impl PickerDelegate for AcpModelPickerDelegate { + type ListItem = AnyElement; + + fn match_count(&self) -> usize { + self.filtered_entries.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { + self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); + cx.notify(); + } + + fn can_select( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + match self.filtered_entries.get(ix) { + Some(AcpModelPickerEntry::Model(_)) => true, + Some(AcpModelPickerEntry::Separator(_)) | None => false, + } + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + "Select a model…".into() + } + + fn update_matches( + &mut self, + query: String, + window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + cx.spawn_in(window, async move |this, cx| { + let filtered_models = match this + .read_with(cx, |this, cx| { + this.delegate.models.clone().map(move |models| { + fuzzy_search(models, query, cx.background_executor().clone()) + }) + }) + .ok() + .flatten() + { + Some(task) => task.await, + None => AgentModelList::Flat(vec![]), + }; + + this.update_in(cx, |this, window, cx| { + this.delegate.filtered_entries = + info_list_to_picker_entries(filtered_models).collect(); + // Finds the currently selected model in the list + let new_index = this + .delegate + .selected_model + .as_ref() + .and_then(|selected| { + this.delegate.filtered_entries.iter().position(|entry| { + if let AcpModelPickerEntry::Model(model_info) = entry { + model_info.id == selected.id + } else { + false + } + }) + }) + .unwrap_or(0); + this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); + cx.notify(); + }) + .ok(); + }) + } + + fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + if let Some(AcpModelPickerEntry::Model(model_info)) = + self.filtered_entries.get(self.selected_index) + { + self.selector + .select_model(self.session_id.clone(), model_info.id.clone(), cx) + .detach_and_log_err(cx); + self.selected_model = Some(model_info.clone()); + let current_index = self.selected_index; + self.set_selected_index(current_index, window, cx); + + cx.emit(DismissEvent); + } + } + + fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { + cx.emit(DismissEvent); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + match self.filtered_entries.get(ix)? { + AcpModelPickerEntry::Separator(title) => Some( + div() + .px_2() + .pb_1() + .when(ix > 1, |this| { + this.mt_1() + .pt_2() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + Label::new(title) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + ), + AcpModelPickerEntry::Model(model_info) => { + let is_selected = Some(model_info) == self.selected_model.as_ref(); + + let model_icon_color = if is_selected { + Color::Accent + } else { + Color::Muted + }; + + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot::(model_info.icon.map(|icon| { + Icon::new(icon) + .color(model_icon_color) + .size(IconSize::Small) + })) + .child( + h_flex() + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.name.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })) + .into_any_element(), + ) + } + } + } + + fn render_footer( + &self, + _: &mut Window, + cx: &mut Context>, + ) -> Option { + Some( + h_flex() + .w_full() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .p_1() + .gap_4() + .justify_between() + .child( + Button::new("configure", "Configure") + .icon(IconName::Settings) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .on_click(|_, window, cx| { + window.dispatch_action( + zed_actions::agent::OpenSettings.boxed_clone(), + cx, + ); + }), + ) + .into_any(), + ) + } +} + +fn info_list_to_picker_entries( + model_list: AgentModelList, +) -> impl Iterator { + match model_list { + AgentModelList::Flat(list) => { + itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model)) + } + AgentModelList::Grouped(index_map) => { + itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| { + std::iter::once(AcpModelPickerEntry::Separator(group_name.0)) + .chain(models.into_iter().map(AcpModelPickerEntry::Model)) + })) + } + } +} + +async fn fuzzy_search( + model_list: AgentModelList, + query: String, + executor: BackgroundExecutor, +) -> AgentModelList { + async fn fuzzy_search_list( + model_list: Vec, + query: &str, + executor: BackgroundExecutor, + ) -> Vec { + let candidates = model_list + .iter() + .enumerate() + .map(|(ix, model)| { + StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name)) + }) + .collect::>(); + let mut matches = match_strings( + &candidates, + &query, + false, + true, + 100, + &Default::default(), + executor, + ) + .await; + + matches.sort_unstable_by_key(|mat| { + let candidate = &candidates[mat.candidate_id]; + (Reverse(OrderedFloat(mat.score)), candidate.id) + }); + + matches + .into_iter() + .map(|mat| model_list[mat.candidate_id].clone()) + .collect() + } + + match model_list { + AgentModelList::Flat(model_list) => { + AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await) + } + AgentModelList::Grouped(index_map) => { + let groups = + futures::future::join_all(index_map.into_iter().map(|(group_name, models)| { + fuzzy_search_list(models, &query, executor.clone()) + .map(|results| (group_name, results)) + })) + .await; + AgentModelList::Grouped(IndexMap::from_iter( + groups + .into_iter() + .filter(|(_, results)| !results.is_empty()), + )) + } + } +} + +#[cfg(test)] +mod tests { + use gpui::TestAppContext; + + use super::*; + + fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList { + AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map( + |(group, models)| { + ( + acp_thread::AgentModelGroupName(group.to_string().into()), + models + .into_iter() + .map(|model| acp_thread::AgentModelInfo { + id: acp_thread::AgentModelId(model.to_string().into()), + name: model.to_string().into(), + icon: None, + }) + .collect::>(), + ) + }, + ))) + } + + fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) { + let AgentModelList::Grouped(groups) = result else { + panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result); + }; + + assert_eq!( + groups.len(), + expected.len(), + "Number of groups doesn't match" + ); + + for (i, (expected_group, expected_models)) in expected.iter().enumerate() { + let (actual_group, actual_models) = groups.get_index(i).unwrap(); + assert_eq!( + actual_group.0.as_ref(), + *expected_group, + "Group at position {} doesn't match expected group", + i + ); + assert_eq!( + actual_models.len(), + expected_models.len(), + "Number of models in group {} doesn't match", + expected_group + ); + + for (j, expected_model_name) in expected_models.iter().enumerate() { + assert_eq!( + actual_models[j].name, *expected_model_name, + "Model at position {} in group {} doesn't match expected model", + j, expected_group + ); + } + } + } + + #[gpui::test] + async fn test_fuzzy_match(cx: &mut TestAppContext) { + let models = create_model_list(vec![ + ( + "zed", + vec![ + "Claude 3.7 Sonnet", + "Claude 3.7 Sonnet Thinking", + "gpt-4.1", + "gpt-4.1-nano", + ], + ), + ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]), + ("ollama", vec!["mistral", "deepseek"]), + ]); + + // Results should preserve models order whenever possible. + // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical + // similarity scores, but `zed/gpt-4.1` was higher in the models list, + // so it should appear first in the results. + let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]), + ], + ); + + // Fuzzy search + let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await; + assert_models_eq( + results, + vec![ + ("zed", vec!["gpt-4.1-nano"]), + ("openai", vec!["gpt-4.1-nano"]), + ], + ); + } +} diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs new file mode 100644 index 0000000000..e52101113a --- /dev/null +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -0,0 +1,85 @@ +use std::rc::Rc; + +use acp_thread::AgentModelSelector; +use agent_client_protocol as acp; +use gpui::{Entity, FocusHandle}; +use picker::popover_menu::PickerPopoverMenu; +use ui::{ + ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*, +}; +use zed_actions::agent::ToggleModelSelector; + +use crate::acp::{AcpModelSelector, model_selector::acp_model_selector}; + +pub struct AcpModelSelectorPopover { + selector: Entity, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, +} + +impl AcpModelSelectorPopover { + pub(crate) fn new( + session_id: acp::SessionId, + selector: Rc, + menu_handle: PopoverMenuHandle, + focus_handle: FocusHandle, + window: &mut Window, + cx: &mut Context, + ) -> Self { + Self { + selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)), + menu_handle, + focus_handle, + } + } + + pub fn toggle(&self, window: &mut Window, cx: &mut Context) { + self.menu_handle.toggle(window, cx); + } +} + +impl Render for AcpModelSelectorPopover { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let model = self.selector.read(cx).delegate.active_model(); + let model_name = model + .as_ref() + .map(|model| model.name.clone()) + .unwrap_or_else(|| SharedString::from("Select a Model")); + + let model_icon = model.as_ref().and_then(|model| model.icon); + + let focus_handle = self.focus_handle.clone(); + + PickerPopoverMenu::new( + self.selector.clone(), + ButtonLike::new("active-model") + .when_some(model_icon, |this, icon| { + this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)) + }) + .child( + Label::new(model_name) + .color(Color::Muted) + .size(LabelSize::Small) + .ml_0p5(), + ) + .child( + Icon::new(IconName::ChevronDown) + .color(Color::Muted) + .size(IconSize::XSmall), + ), + move |window, cx| { + Tooltip::for_action_in( + "Change Model", + &ToggleModelSelector, + &focus_handle, + window, + cx, + ) + }, + gpui::Corner::BottomRight, + cx, + ) + .with_handle(self.menu_handle.clone()) + .render(window, cx) + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index da7915222e..12fc29b08f 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -38,12 +38,14 @@ use terminal_view::TerminalView; use text::{Anchor, BufferSnapshot}; use theme::ThemeSettings; use ui::{ - Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*, + Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, + Tooltip, prelude::*, }; use util::{ResultExt, size::format_file_size, time::duration_alt_display}; use workspace::{CollaboratorId, Workspace}; -use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; +use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage, ToggleModelSelector}; +use crate::acp::AcpModelSelectorPopover; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::message_history::MessageHistory; use crate::agent_diff::AgentDiff; @@ -63,6 +65,7 @@ pub struct AcpThreadView { diff_editors: HashMap>, terminal_views: HashMap>, message_editor: Entity, + model_selector: Option>, message_set_from_history: Option, _message_editor_subscription: Subscription, mention_set: Arc>, @@ -187,6 +190,7 @@ impl AcpThreadView { project: project.clone(), thread_state: Self::initial_state(agent, workspace, project, window, cx), message_editor, + model_selector: None, message_set_from_history: None, _message_editor_subscription: message_editor_subscription, mention_set, @@ -270,7 +274,7 @@ impl AcpThreadView { Err(e) } } - Ok(session_id) => Ok(session_id), + Ok(thread) => Ok(thread), }; this.update_in(cx, |this, window, cx| { @@ -288,6 +292,24 @@ impl AcpThreadView { AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); + this.model_selector = + thread + .read(cx) + .connection() + .model_selector() + .map(|selector| { + cx.new(|cx| { + AcpModelSelectorPopover::new( + thread.read(cx).session_id().clone(), + selector, + PopoverMenuHandle::default(), + this.focus_handle(cx), + window, + cx, + ) + }) + }); + this.thread_state = ThreadState::Ready { thread, _subscription: [thread_subscription, action_log_subscription], @@ -2472,6 +2494,12 @@ impl AcpThreadView { v_flex() .on_action(cx.listener(Self::expand_message_editor)) + .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { + if let Some(model_selector) = this.model_selector.as_ref() { + model_selector + .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); + } + })) .p_2() .gap_2() .border_t_1() @@ -2548,7 +2576,12 @@ impl AcpThreadView { .flex_none() .justify_between() .child(self.render_follow_toggle(cx)) - .child(self.render_send_button(cx)), + .child( + h_flex() + .gap_1() + .children(self.model_selector.clone()) + .child(self.render_send_button(cx)), + ), ) .into_any() } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 87e4dd822c..d07581da93 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -916,6 +916,7 @@ impl AgentPanel { let workspace = self.workspace.clone(); let project = self.project.clone(); let message_history = self.acp_message_history.clone(); + let fs = self.fs.clone(); const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent"; @@ -939,7 +940,7 @@ impl AgentPanel { }) .detach(); - agent.server() + agent.server(fs) } None => cx .background_spawn(async move { @@ -953,7 +954,7 @@ impl AgentPanel { }) .unwrap_or_default() .agent - .server(), + .server(fs), }; this.update_in(cx, |this, window, cx| { diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index fceb8f4c45..b776c0830b 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -155,11 +155,11 @@ enum ExternalAgent { } impl ExternalAgent { - pub fn server(&self) -> Rc { + pub fn server(&self, fs: Arc) -> Rc { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)), } } }