Add ModelSelector capability to AgentConnection

- Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods
- Extend AgentConnection trait with optional model_selector() method returning Option<Rc<dyn ModelSelector>>
- Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry
- Make selected_model field mandatory on Thread struct
- Update Thread::new to require a default_model parameter
- Update agent2 to fetch default model from registry when creating threads
- Fix prompt method to use the thread's selected model directly
- All methods use &mut AsyncApp for async-friendly operations
This commit is contained in:
Nathan Sobo 2025-08-02 08:28:37 -06:00
parent 5d621bef78
commit a4fe8c6972
6 changed files with 144 additions and 11 deletions

1
Cargo.lock generated
View file

@ -19,6 +19,7 @@ dependencies = [
"indoc", "indoc",
"itertools 0.14.0", "itertools 0.14.0",
"language", "language",
"language_model",
"markdown", "markdown",
"project", "project",
"serde", "serde",

View file

@ -26,6 +26,7 @@ futures.workspace = true
gpui.workspace = true gpui.workspace = true
itertools.workspace = true itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true
markdown.workspace = true markdown.workspace = true
project.workspace = true project.workspace = true
serde.workspace = true serde.workspace = true

View file

@ -1,13 +1,61 @@
use std::{error::Error, fmt, path::Path, rc::Rc}; use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use gpui::{AsyncApp, Entity, Task}; use gpui::{AsyncApp, Entity, Task};
use language_model::LanguageModel;
use project::Project; use project::Project;
use ui::App; use ui::App;
use crate::AcpThread; use crate::AcpThread;
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait ModelSelector: 'static {
/// Lists all available language models for this agent.
///
/// # Parameters
/// - `cx`: The GPUI app context for async operations and global access.
///
/// # Returns
/// A task resolving to the list of models or an error (e.g., if no models are configured).
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
/// Selects a model for a specific session (thread).
///
/// This sets the default model for future interactions in the session.
/// If the session doesn't exist or the model is invalid, it returns an error.
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to apply the model to.
/// - `model`: The model to select (should be one from [list_models]).
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to `Ok(())` on success or an error.
fn select_model(
&self,
session_id: &acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>>;
/// Retrieves the currently selected model for a specific session (thread).
///
/// # Parameters
/// - `session_id`: The ID of the session (thread) to query.
/// - `cx`: The GPUI app context.
///
/// # Returns
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>>;
}
pub trait AgentConnection { pub trait AgentConnection {
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
@ -23,6 +71,14 @@ pub trait AgentConnection {
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>; fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
None // Default impl for agents that don't support it
}
} }
#[derive(Debug)] #[derive(Debug)]

View file

@ -1,6 +1,8 @@
use acp_thread::ModelSelector;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::Result; use anyhow::Result;
use gpui::{App, AppContext, AsyncApp, Entity, Task}; use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::Project; use project::Project;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path; use std::path::Path;
@ -26,8 +28,67 @@ impl Agent {
} }
/// Wrapper struct that implements the AgentConnection trait /// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct AgentConnection(pub Entity<Agent>); pub struct AgentConnection(pub Entity<Agent>);
impl ModelSelector for AgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
let result = cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
});
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
}
fn select_model(
&self,
session_id: &acp::SessionId,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
let agent = self.0.clone();
let result = agent.update(cx, |agent, cx| {
if let Some(thread) = agent.sessions.get(session_id) {
thread.update(cx, |thread, _| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow::anyhow!("Session not found"))
}
});
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
let thread_result = agent
.read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
.ok()
.flatten()
.ok_or_else(|| anyhow::anyhow!("Session not found"));
match thread_result {
Ok(thread) => {
let selected = thread
.read_with(cx, |thread, _| thread.selected_model.clone())
.unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
Task::ready(Ok(selected))
}
Err(e) => Task::ready(Err(e)),
}
}
}
impl acp_thread::AgentConnection for AgentConnection { impl acp_thread::AgentConnection for AgentConnection {
fn new_thread( fn new_thread(
self: Rc<Self>, self: Rc<Self>,
@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection {
// Create Thread and store in Agent // Create Thread and store in Agent
let (session_id, _thread) = let (session_id, _thread) =
agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| { agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
let thread = cx.new(|_| Thread::new(agent.templates.clone())); // Fetch default model
let default_model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.next()
.unwrap_or_else(|| panic!("No default model available"));
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
agent.sessions.insert(session_id.clone(), thread.clone()); agent.sessions.insert(session_id.clone(), thread.clone());
(session_id, thread) (session_id, thread)
@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection {
// Create AcpThread // Create AcpThread
let acp_thread = cx.update(|cx| { let acp_thread = cx.update(|cx| {
cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx)) cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
})
})?; })?;
Ok(acp_thread) Ok(acp_thread)
@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let session_id = params.session_id.clone(); let session_id = params.session_id.clone();
let agent = self.0.clone(); let agent = self.0.clone();
cx.spawn(|cx| async move { cx.spawn(async move |cx| {
// Get thread // Get thread
let thread: Entity<Thread> = agent let thread: Entity<Thread> = agent
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection {
// Convert prompt to message // Convert prompt to message
let message = convert_prompt_to_message(params.prompt); let message = convert_prompt_to_message(params.prompt);
// TODO: Get model from somewhere - for now use a placeholder // Get model using the ModelSelector capability (always available for agent2)
log::warn!("Model selection not implemented - need to get from UI context"); // Get the selected model from the thread directly
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread // Send to thread
// thread.update(&mut cx, |thread, cx| { thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
// thread.send(model, message, cx)
// })?;
Ok(()) Ok(())
}) })

View file

@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
cx.executor().allow_parking(); cx.executor().allow_parking();
cx.update(settings::init); cx.update(settings::init);
let templates = Templates::new(); let templates = Templates::new();
let thread = cx.new(|_| Thread::new(templates));
let model = cx let model = cx
.update(|cx| { .update(|cx| {
@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
}) })
.await; .await;
let thread = cx.new(|_| Thread::new(templates, model.clone()));
ThreadTest { model, thread } ThreadTest { model, thread }
} }

View file

@ -37,12 +37,13 @@ pub struct Thread {
system_prompts: Vec<Arc<dyn Prompt>>, system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>, tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
templates: Arc<Templates>, templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
// project: Entity<Project>, // project: Entity<Project>,
// action_log: Entity<ActionLog>, // action_log: Entity<ActionLog>,
} }
impl Thread { impl Thread {
pub fn new(templates: Arc<Templates>) -> Self { pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
Self { Self {
messages: Vec::new(), messages: Vec::new(),
completion_mode: CompletionMode::Normal, completion_mode: CompletionMode::Normal,
@ -50,6 +51,7 @@ impl Thread {
running_turn: None, running_turn: None,
tools: BTreeMap::default(), tools: BTreeMap::default(),
templates, templates,
selected_model: default_model,
} }
} }