Add ModelSelector capability to AgentConnection
- Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods - Extend AgentConnection trait with optional model_selector() method returning Option<Rc<dyn ModelSelector>> - Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry - Make selected_model field mandatory on Thread struct - Update Thread::new to require a default_model parameter - Update agent2 to fetch default model from registry when creating threads - Fix prompt method to use the thread's selected model directly - All methods use &mut AsyncApp for async-friendly operations
This commit is contained in:
parent
5d621bef78
commit
a4fe8c6972
6 changed files with 144 additions and 11 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -19,6 +19,7 @@ dependencies = [
|
||||||
"indoc",
|
"indoc",
|
||||||
"itertools 0.14.0",
|
"itertools 0.14.0",
|
||||||
"language",
|
"language",
|
||||||
|
"language_model",
|
||||||
"markdown",
|
"markdown",
|
||||||
"project",
|
"project",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -26,6 +26,7 @@ futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
itertools.workspace = true
|
itertools.workspace = true
|
||||||
language.workspace = true
|
language.workspace = true
|
||||||
|
language_model.workspace = true
|
||||||
markdown.workspace = true
|
markdown.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
|
|
@ -1,13 +1,61 @@
|
||||||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
|
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{AsyncApp, Entity, Task};
|
use gpui::{AsyncApp, Entity, Task};
|
||||||
|
use language_model::LanguageModel;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use ui::App;
|
use ui::App;
|
||||||
|
|
||||||
use crate::AcpThread;
|
use crate::AcpThread;
|
||||||
|
|
||||||
|
/// Trait for agents that support listing, selecting, and querying language models.
|
||||||
|
///
|
||||||
|
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||||
|
pub trait ModelSelector: 'static {
|
||||||
|
/// Lists all available language models for this agent.
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
/// - `cx`: The GPUI app context for async operations and global access.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||||
|
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||||
|
|
||||||
|
/// Selects a model for a specific session (thread).
|
||||||
|
///
|
||||||
|
/// This sets the default model for future interactions in the session.
|
||||||
|
/// If the session doesn't exist or the model is invalid, it returns an error.
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
/// - `session_id`: The ID of the session (thread) to apply the model to.
|
||||||
|
/// - `model`: The model to select (should be one from [list_models]).
|
||||||
|
/// - `cx`: The GPUI app context.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A task resolving to `Ok(())` on success or an error.
|
||||||
|
fn select_model(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<()>>;
|
||||||
|
|
||||||
|
/// Retrieves the currently selected model for a specific session (thread).
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
/// - `session_id`: The ID of the session (thread) to query.
|
||||||
|
/// - `cx`: The GPUI app context.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
|
||||||
|
fn selected_model(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait AgentConnection {
|
pub trait AgentConnection {
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
|
@ -23,6 +71,14 @@ pub trait AgentConnection {
|
||||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||||
|
|
||||||
|
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||||
|
///
|
||||||
|
/// If the agent does not support model selection, returns [None].
|
||||||
|
/// This allows sharing the selector in UI components.
|
||||||
|
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||||
|
None // Default impl for agents that don't support it
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
|
use acp_thread::ModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||||
|
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
@ -26,8 +28,67 @@ impl Agent {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrapper struct that implements the AgentConnection trait
|
/// Wrapper struct that implements the AgentConnection trait
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct AgentConnection(pub Entity<Agent>);
|
pub struct AgentConnection(pub Entity<Agent>);
|
||||||
|
|
||||||
|
impl ModelSelector for AgentConnection {
|
||||||
|
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||||
|
let result = cx.update(|cx| {
|
||||||
|
let registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||||
|
if models.is_empty() {
|
||||||
|
Err(anyhow::anyhow!("No models available"))
|
||||||
|
} else {
|
||||||
|
Ok(models)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select_model(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<()>> {
|
||||||
|
let agent = self.0.clone();
|
||||||
|
let result = agent.update(cx, |agent, cx| {
|
||||||
|
if let Some(thread) = agent.sessions.get(session_id) {
|
||||||
|
thread.update(cx, |thread, _| {
|
||||||
|
thread.selected_model = model;
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow::anyhow!("Session not found"))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn selected_model(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||||
|
let agent = self.0.clone();
|
||||||
|
let thread_result = agent
|
||||||
|
.read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Session not found"));
|
||||||
|
|
||||||
|
match thread_result {
|
||||||
|
Ok(thread) => {
|
||||||
|
let selected = thread
|
||||||
|
.read_with(cx, |thread, _| thread.selected_model.clone())
|
||||||
|
.unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
|
||||||
|
Task::ready(Ok(selected))
|
||||||
|
}
|
||||||
|
Err(e) => Task::ready(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl acp_thread::AgentConnection for AgentConnection {
|
impl acp_thread::AgentConnection for AgentConnection {
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
|
@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection {
|
||||||
// Create Thread and store in Agent
|
// Create Thread and store in Agent
|
||||||
let (session_id, _thread) =
|
let (session_id, _thread) =
|
||||||
agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
|
agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
|
||||||
let thread = cx.new(|_| Thread::new(agent.templates.clone()));
|
// Fetch default model
|
||||||
|
let default_model = LanguageModelRegistry::read_global(cx)
|
||||||
|
.available_models(cx)
|
||||||
|
.next()
|
||||||
|
.unwrap_or_else(|| panic!("No default model available"));
|
||||||
|
|
||||||
|
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
|
||||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||||
agent.sessions.insert(session_id.clone(), thread.clone());
|
agent.sessions.insert(session_id.clone(), thread.clone());
|
||||||
(session_id, thread)
|
(session_id, thread)
|
||||||
|
@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection {
|
||||||
|
|
||||||
// Create AcpThread
|
// Create AcpThread
|
||||||
let acp_thread = cx.update(|cx| {
|
let acp_thread = cx.update(|cx| {
|
||||||
cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx))
|
cx.new(|cx| {
|
||||||
|
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
|
||||||
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(acp_thread)
|
Ok(acp_thread)
|
||||||
|
@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||||
|
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||||
|
}
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||||
let session_id = params.session_id.clone();
|
let session_id = params.session_id.clone();
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
|
|
||||||
cx.spawn(|cx| async move {
|
cx.spawn(async move |cx| {
|
||||||
// Get thread
|
// Get thread
|
||||||
let thread: Entity<Thread> = agent
|
let thread: Entity<Thread> = agent
|
||||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
||||||
|
@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection {
|
||||||
// Convert prompt to message
|
// Convert prompt to message
|
||||||
let message = convert_prompt_to_message(params.prompt);
|
let message = convert_prompt_to_message(params.prompt);
|
||||||
|
|
||||||
// TODO: Get model from somewhere - for now use a placeholder
|
// Get model using the ModelSelector capability (always available for agent2)
|
||||||
log::warn!("Model selection not implemented - need to get from UI context");
|
// Get the selected model from the thread directly
|
||||||
|
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||||
|
|
||||||
// Send to thread
|
// Send to thread
|
||||||
// thread.update(&mut cx, |thread, cx| {
|
thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
||||||
// thread.send(model, message, cx)
|
|
||||||
// })?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
|
|
|
@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
|
||||||
cx.executor().allow_parking();
|
cx.executor().allow_parking();
|
||||||
cx.update(settings::init);
|
cx.update(settings::init);
|
||||||
let templates = Templates::new();
|
let templates = Templates::new();
|
||||||
let thread = cx.new(|_| Thread::new(templates));
|
|
||||||
|
|
||||||
let model = cx
|
let model = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
|
@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
let thread = cx.new(|_| Thread::new(templates, model.clone()));
|
||||||
|
|
||||||
ThreadTest { model, thread }
|
ThreadTest { model, thread }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,12 +37,13 @@ pub struct Thread {
|
||||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
templates: Arc<Templates>,
|
templates: Arc<Templates>,
|
||||||
|
pub selected_model: Arc<dyn LanguageModel>,
|
||||||
// project: Entity<Project>,
|
// project: Entity<Project>,
|
||||||
// action_log: Entity<ActionLog>,
|
// action_log: Entity<ActionLog>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
pub fn new(templates: Arc<Templates>) -> Self {
|
pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
completion_mode: CompletionMode::Normal,
|
completion_mode: CompletionMode::Normal,
|
||||||
|
@ -50,6 +51,7 @@ impl Thread {
|
||||||
running_turn: None,
|
running_turn: None,
|
||||||
tools: BTreeMap::default(),
|
tools: BTreeMap::default(),
|
||||||
templates,
|
templates,
|
||||||
|
selected_model: default_model,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue