From a4fe8c69722c06f9494ccfcd1909e1f463c3e34d Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 08:28:37 -0600 Subject: [PATCH] Add ModelSelector capability to AgentConnection - Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods - Extend AgentConnection trait with optional model_selector() method returning Option> - Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry - Make selected_model field mandatory on Thread struct - Update Thread::new to require a default_model parameter - Update agent2 to fetch default model from registry when creating threads - Fix prompt method to use the thread's selected model directly - All methods use &mut AsyncApp for async-friendly operations --- Cargo.lock | 1 + crates/acp_thread/Cargo.toml | 1 + crates/acp_thread/src/connection.rs | 58 ++++++++++++++++++- crates/agent2/src/agent.rs | 88 ++++++++++++++++++++++++++--- crates/agent2/src/tests/mod.rs | 3 +- crates/agent2/src/thread.rs | 4 +- 6 files changed, 144 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 705ed44d38..d81effe7a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "project", "serde", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 011f26f364..308756b038 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -26,6 +26,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 929500a67b..b99e4949d8 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,13 +1,61 @@ -use std::{error::Error, fmt, path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; +use language_model::LanguageModel; use project::Project; use ui::App; use crate::AcpThread; +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait ModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: &acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>>; +} + pub trait AgentConnection { fn new_thread( self: Rc, @@ -23,6 +71,14 @@ pub trait AgentConnection { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None // Default impl for agents that don't support it + } } #[derive(Debug)] diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index c1c28ad41b..f10738313e 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,6 +1,8 @@ +use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::Result; use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use language_model::{LanguageModel, LanguageModelRegistry}; use project::Project; use std::collections::HashMap; use std::path::Path; @@ -26,8 +28,67 @@ impl Agent { } /// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] pub struct AgentConnection(pub Entity); +impl ModelSelector for AgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + let result = cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + }); + Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + } + + fn select_model( + &self, + session_id: &acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + let agent = self.0.clone(); + let result = agent.update(cx, |agent, cx| { + if let Some(thread) = agent.sessions.get(session_id) { + thread.update(cx, |thread, _| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow::anyhow!("Session not found")) + } + }); + Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let thread_result = agent + .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned()) + .ok() + .flatten() + .ok_or_else(|| anyhow::anyhow!("Session not found")); + + match thread_result { + Ok(thread) => { + let selected = thread + .read_with(cx, |thread, _| thread.selected_model.clone()) + .unwrap_or_else(|e| panic!("Failed to read thread: {}", e)); + Task::ready(Ok(selected)) + } + Err(e) => Task::ready(Err(e)), + } + } +} + impl acp_thread::AgentConnection for AgentConnection { fn new_thread( self: Rc, @@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection { // Create Thread and store in Agent let (session_id, _thread) = agent.update(cx, |agent, cx: &mut gpui::Context| { - let thread = cx.new(|_| Thread::new(agent.templates.clone())); + // Fetch default model + let default_model = LanguageModelRegistry::read_global(cx) + .available_models(cx) + .next() + .unwrap_or_else(|| panic!("No default model available")); + + let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model)); let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); agent.sessions.insert(session_id.clone(), thread.clone()); (session_id, thread) @@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection { // Create AcpThread let acp_thread = cx.update(|cx| { - cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx)) + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx) + }) })?; Ok(acp_thread) @@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection { Task::ready(Ok(())) } + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let session_id = params.session_id.clone(); let agent = self.0.clone(); - cx.spawn(|cx| async move { + cx.spawn(async move |cx| { // Get thread let thread: Entity = agent .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? @@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection { // Convert prompt to message let message = convert_prompt_to_message(params.prompt); - // TODO: Get model from somewhere - for now use a placeholder - log::warn!("Model selection not implemented - need to get from UI context"); + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; // Send to thread - // thread.update(&mut cx, |thread, cx| { - // thread.send(model, message, cx) - // })?; + thread.update(cx, |thread, cx| thread.send(model, message, cx))?; Ok(()) }) diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index f3d9a35c2b..f7dc9055f6 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest { cx.executor().allow_parking(); cx.update(settings::init); let templates = Templates::new(); - let thread = cx.new(|_| Thread::new(templates)); let model = cx .update(|cx| { @@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest { }) .await; + let thread = cx.new(|_| Thread::new(templates, model.clone())); + ThreadTest { model, thread } } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index bc88cf1d95..758e940269 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -37,12 +37,13 @@ pub struct Thread { system_prompts: Vec>, tools: BTreeMap>, templates: Arc, + pub selected_model: Arc, // project: Entity, // action_log: Entity, } impl Thread { - pub fn new(templates: Arc) -> Self { + pub fn new(templates: Arc, default_model: Arc) -> Self { Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, @@ -50,6 +51,7 @@ impl Thread { running_turn: None, tools: BTreeMap::default(), templates, + selected_model: default_model, } }