From 0f395df9a8f8f8dcceeed3f5cd98e7ba5014cc1d Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 28 Jul 2025 18:02:21 -0300 Subject: [PATCH 01/25] Update to new schema --- Cargo.lock | 4 +--- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 14 +++++++------- crates/agent_servers/src/codex.rs | 12 ++++-------- crates/agent_servers/src/mcp_server.rs | 7 ++++--- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8d9a622655..eb034f4cc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,7 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b" +version = "0.0.12" dependencies = [ "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index 16ace7dee0..d733f2242e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.11" +agent-client-protocol = {path="../agent-client-protocol"} aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d572992c54..d02f2d6bb6 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -391,7 +391,7 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { + acp::ToolCallContent::Content { content } => Self::ContentBlock { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { @@ -682,14 +682,14 @@ impl AcpThread { cx: &mut Context, ) -> Result<()> { match update { - acp::SessionUpdate::UserMessage(content_block) => { - self.push_user_content_block(content_block, cx); + acp::SessionUpdate::UserMessageChunk { content } => { + self.push_user_content_block(content, cx); } - acp::SessionUpdate::AgentMessageChunk(content_block) => { - self.push_assistant_content_block(content_block, false, cx); + acp::SessionUpdate::AgentMessageChunk { content } => { + self.push_assistant_content_block(content, false, cx); } - acp::SessionUpdate::AgentThoughtChunk(content_block) => { - self.push_assistant_content_block(content_block, true, cx); + acp::SessionUpdate::AgentThoughtChunk { content } => { + self.push_assistant_content_block(content, true, cx); } acp::SessionUpdate::ToolCall(tool_call) => { self.upsert_tool_call(tool_call, cx); diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index b10ce9cf54..d40aebfbd7 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -73,7 +73,7 @@ impl AgentServer for Codex { client .client() .context("Failed to subscribe")? - .on_notification(acp::SESSION_UPDATE_METHOD_NAME, { + .on_notification(acp::AGENT_METHODS.session_update, { move |notification, _cx| { let notification_tx = notification_tx.clone(); log::trace!( @@ -149,13 +149,9 @@ impl AgentConnection for CodexConnection { let response = client .request::(context_server::types::CallToolParams { - name: acp::NEW_SESSION_TOOL_NAME.into(), + name: acp::AGENT_METHODS.new_session.into(), arguments: Some(serde_json::to_value(acp::NewSessionArguments { - mcp_servers: [( - mcp_server::SERVER_NAME.to_string(), - mcp_server.server_config()?, - )] - .into(), + mcp_servers: vec![mcp_server.server_config()?], client_tools: acp::ClientTools { request_permission: Some(acp::McpToolId { mcp_server: mcp_server::SERVER_NAME.into(), @@ -227,7 +223,7 @@ impl AgentConnection for CodexConnection { let result = client .request_with::( context_server::types::CallToolParams { - name: acp::PROMPT_TOOL_NAME.into(), + name: acp::AGENT_METHODS.prompt.into(), arguments: Some(serde_json::to_value(params)?), meta: None, }, diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs index 055b89dfe2..ec655800ed 100644 --- a/crates/agent_servers/src/mcp_server.rs +++ b/crates/agent_servers/src/mcp_server.rs @@ -37,7 +37,7 @@ impl ZedMcpServer { Ok(Self { server: mcp_server }) } - pub fn server_config(&self) -> Result { + pub fn server_config(&self) -> Result { #[cfg(not(test))] let zed_path = anyhow::Context::context( std::env::current_exe(), @@ -47,13 +47,14 @@ impl ZedMcpServer { #[cfg(test)] let zed_path = crate::e2e_tests::get_zed_path(); - Ok(acp::McpServerConfig { + Ok(acp::McpServer { + name: SERVER_NAME.into(), command: zed_path, args: vec![ "--nc".into(), self.server.socket_path().display().to_string(), ], - env: None, + env: vec![], }) } From ced3d09f10c1d64fa8dd31065724e4b12b69e9ee Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 28 Jul 2025 18:32:30 -0300 Subject: [PATCH 02/25] Extract acp_connection --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 3 +- crates/acp_thread/src/connection.rs | 2 - crates/acp_thread/src/old_acp_support.rs | 6 +- crates/agent_servers/acp | 0 crates/agent_servers/src/acp_connection.rs | 256 +++++++++++++++++++++ crates/agent_servers/src/agent_servers.rs | 1 + crates/agent_servers/src/claude.rs | 9 +- crates/agent_servers/src/codex.rs | 255 +------------------- crates/agent_servers/src/gemini.rs | 182 +++------------ 11 files changed, 305 insertions(+), 415 deletions(-) create mode 100644 crates/agent_servers/acp create mode 100644 crates/agent_servers/src/acp_connection.rs diff --git a/Cargo.lock b/Cargo.lock index eb034f4cc3..e5b1be5a8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,7 +138,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.12" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142" dependencies = [ "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index d733f2242e..81da82cbb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = {path="../agent-client-protocol"} +agent-client-protocol = "0.0.13" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d02f2d6bb6..1c4b0ec06f 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -616,6 +616,7 @@ impl Error for LoadError {} impl AcpThread { pub fn new( + title: impl Into, connection: Rc, project: Entity, session_id: acp::SessionId, @@ -628,7 +629,7 @@ impl AcpThread { shared_buffers: Default::default(), entries: Default::default(), plan: Default::default(), - title: connection.name().into(), + title: title.into(), project, send_task: None, connection, diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 5b25b71863..97161a19c0 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -9,8 +9,6 @@ use ui::App; use crate::AcpThread; pub trait AgentConnection { - fn name(&self) -> &'static str; - fn new_thread( self: Rc, project: Entity, diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 44cd00348f..d7ef1b73da 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -367,10 +367,6 @@ pub struct OldAcpAgentConnection { } impl AgentConnection for OldAcpAgentConnection { - fn name(&self) -> &'static str { - self.name - } - fn new_thread( self: Rc, project: Entity, @@ -394,7 +390,7 @@ impl AgentConnection for OldAcpAgentConnection { cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.clone(), project, session_id, cx) + AcpThread::new("Gemini", self.clone(), project, session_id, cx) }); thread }) diff --git a/crates/agent_servers/acp b/crates/agent_servers/acp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs new file mode 100644 index 0000000000..9139d62c38 --- /dev/null +++ b/crates/agent_servers/src/acp_connection.rs @@ -0,0 +1,256 @@ +use agent_client_protocol as acp; +use anyhow::anyhow; +use collections::HashMap; +use context_server::listener::McpServerTool; +use context_server::types::requests; +use context_server::{ContextServer, ContextServerCommand, ContextServerId}; +use futures::channel::{mpsc, oneshot}; +use project::Project; +use smol::stream::StreamExt as _; +use std::cell::RefCell; +use std::rc::Rc; +use std::{path::Path, sync::Arc}; +use util::ResultExt; + +use anyhow::{Context, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::mcp_server::ZedMcpServer; +use crate::{AgentServerCommand, mcp_server}; +use acp_thread::{AcpThread, AgentConnection}; + +pub struct AcpConnection { + server_name: &'static str, + client: Arc, + sessions: Rc>>, + _notification_handler_task: Task<()>, +} + +impl AcpConnection { + pub async fn stdio( + server_name: &'static str, + command: AgentServerCommand, + cx: &mut AsyncApp, + ) -> Result { + let client: Arc = ContextServer::stdio( + ContextServerId(format!("{}-mcp-server", server_name).into()), + ContextServerCommand { + path: command.path, + args: command.args, + env: command.env, + }, + ) + .into(); + ContextServer::start(client.clone(), cx).await?; + + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + client + .client() + .context("Failed to subscribe")? + .on_notification(acp::AGENT_METHODS.session_update, { + move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification).log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + } + }); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let notification_handler_task = cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + while let Some(notification) = notification_rx.next().await { + Self::handle_session_notification(notification, sessions.clone(), cx) + } + } + }); + + Ok(Self { + server_name, + client, + sessions, + _notification_handler_task: notification_handler_task, + }) + } + + pub fn handle_session_notification( + notification: acp::SessionNotification, + threads: Rc>>, + cx: &mut AsyncApp, + ) { + let threads = threads.borrow(); + let Some(thread) = threads + .get(¬ification.session_id) + .and_then(|session| session.thread.upgrade()) + else { + log::error!( + "Thread not found for session ID: {}", + notification.session_id + ); + return; + }; + + thread + .update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + }) + .log_err(); + } +} + +pub struct AcpSession { + thread: WeakEntity, + cancel_tx: Option>, + _mcp_server: ZedMcpServer, +} + +impl AgentConnection for AcpConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let client = self.client.client(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let client = client.context("MCP server is not initialized yet")?; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + + let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; + + let response = client + .request::(context_server::types::CallToolParams { + name: acp::AGENT_METHODS.new_session.into(), + arguments: Some(serde_json::to_value(acp::NewSessionArguments { + mcp_servers: vec![mcp_server.server_config()?], + client_tools: acp::ClientTools { + request_permission: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::RequestPermissionTool::NAME.into(), + }), + read_text_file: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::ReadTextFileTool::NAME.into(), + }), + write_text_file: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::WriteTextFileTool::NAME.into(), + }), + }, + cwd, + })?), + meta: None, + }) + .await?; + + if response.is_error.unwrap_or_default() { + return Err(anyhow!(response.text_contents())); + } + + let result = serde_json::from_value::( + response.structured_content.context("Empty response")?, + )?; + + let thread = cx.new(|cx| { + AcpThread::new( + self.server_name, + self.clone(), + project, + result.session_id.clone(), + cx, + ) + })?; + + thread_tx.send(thread.downgrade())?; + + let session = AcpSession { + thread: thread.downgrade(), + cancel_tx: None, + _mcp_server: mcp_server, + }; + sessions.borrow_mut().insert(result.session_id, session); + + Ok(thread) + }) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } + + fn prompt( + &self, + params: agent_client_protocol::PromptArguments, + cx: &mut App, + ) -> Task> { + let client = self.client.client(); + let sessions = self.sessions.clone(); + + cx.foreground_executor().spawn(async move { + let client = client.context("MCP server is not initialized yet")?; + + let (new_cancel_tx, cancel_rx) = oneshot::channel(); + { + let mut sessions = sessions.borrow_mut(); + let session = sessions + .get_mut(¶ms.session_id) + .context("Session not found")?; + session.cancel_tx.replace(new_cancel_tx); + } + + let result = client + .request_with::( + context_server::types::CallToolParams { + name: acp::AGENT_METHODS.prompt.into(), + arguments: Some(serde_json::to_value(params)?), + meta: None, + }, + Some(cancel_rx), + None, + ) + .await; + + if let Err(err) = &result + && err.is::() + { + return Ok(()); + } + + let response = result?; + + if response.is_error.unwrap_or_default() { + return Err(anyhow!(response.text_contents())); + } + + Ok(()) + }) + } + + fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { + let mut sessions = self.sessions.borrow_mut(); + + if let Some(cancel_tx) = sessions + .get_mut(session_id) + .and_then(|session| session.cancel_tx.take()) + { + cancel_tx.send(()).ok(); + } + } +} + +impl Drop for AcpConnection { + fn drop(&mut self) { + self.client.stop().log_err(); + } +} diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 212bb74d8a..6a031a190e 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,3 +1,4 @@ +mod acp_connection; mod claude; mod codex; mod gemini; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 6565786204..590da69cd8 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -70,10 +70,6 @@ struct ClaudeAgentConnection { } impl AgentConnection for ClaudeAgentConnection { - fn name(&self) -> &'static str { - ClaudeCode.name() - } - fn new_thread( self: Rc, project: Entity, @@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection { } }); - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?; + let thread = cx.new(|cx| { + AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + })?; thread_tx.send(thread.downgrade())?; diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index d40aebfbd7..1909781efa 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -1,24 +1,14 @@ -use agent_client_protocol as acp; -use anyhow::anyhow; -use collections::HashMap; -use context_server::listener::McpServerTool; -use context_server::types::requests; -use context_server::{ContextServer, ContextServerCommand, ContextServerId}; -use futures::channel::{mpsc, oneshot}; use project::Project; use settings::SettingsStore; -use smol::stream::StreamExt as _; -use std::cell::RefCell; +use std::path::Path; use std::rc::Rc; -use std::{path::Path, sync::Arc}; -use util::ResultExt; -use anyhow::{Context, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use anyhow::Result; +use gpui::{App, Entity, Task}; -use crate::mcp_server::ZedMcpServer; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server}; -use acp_thread::{AcpThread, AgentConnection}; +use crate::acp_connection::AcpConnection; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; +use acp_thread::AgentConnection; #[derive(Clone)] pub struct Codex; @@ -47,6 +37,7 @@ impl AgentServer for Codex { cx: &mut App, ) -> Task>> { let project = project.clone(); + let server_name = self.name(); cx.spawn(async move |cx| { let settings = cx.read_global(|settings: &SettingsStore, _| { settings.get::(None).codex.clone() @@ -58,240 +49,12 @@ impl AgentServer for Codex { anyhow::bail!("Failed to find codex binary"); }; - let client: Arc = ContextServer::stdio( - ContextServerId("codex-mcp-server".into()), - ContextServerCommand { - path: command.path, - args: command.args, - env: command.env, - }, - ) - .into(); - ContextServer::start(client.clone(), cx).await?; - - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - client - .client() - .context("Failed to subscribe")? - .on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification) - .log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); - - let sessions = Rc::new(RefCell::new(HashMap::default())); - - let notification_handler_task = cx.spawn({ - let sessions = sessions.clone(); - async move |cx| { - while let Some(notification) = notification_rx.next().await { - CodexConnection::handle_session_notification( - notification, - sessions.clone(), - cx, - ) - } - } - }); - - let connection = CodexConnection { - client, - sessions, - _notification_handler_task: notification_handler_task, - }; - Ok(Rc::new(connection) as _) + let conn = AcpConnection::stdio(server_name, command, cx).await?; + Ok(Rc::new(conn) as _) }) } } -struct CodexConnection { - client: Arc, - sessions: Rc>>, - _notification_handler_task: Task<()>, -} - -struct CodexSession { - thread: WeakEntity, - cancel_tx: Option>, - _mcp_server: ZedMcpServer, -} - -impl AgentConnection for CodexConnection { - fn name(&self) -> &'static str { - "Codex" - } - - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut AsyncApp, - ) -> Task>> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - let cwd = cwd.to_path_buf(); - cx.spawn(async move |cx| { - let client = client.context("MCP server is not initialized yet")?; - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - - let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; - - let response = client - .request::(context_server::types::CallToolParams { - name: acp::AGENT_METHODS.new_session.into(), - arguments: Some(serde_json::to_value(acp::NewSessionArguments { - mcp_servers: vec![mcp_server.server_config()?], - client_tools: acp::ClientTools { - request_permission: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::RequestPermissionTool::NAME.into(), - }), - read_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::ReadTextFileTool::NAME.into(), - }), - write_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::WriteTextFileTool::NAME.into(), - }), - }, - cwd, - })?), - meta: None, - }) - .await?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - let result = serde_json::from_value::( - response.structured_content.context("Empty response")?, - )?; - - let thread = - cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?; - - thread_tx.send(thread.downgrade())?; - - let session = CodexSession { - thread: thread.downgrade(), - cancel_tx: None, - _mcp_server: mcp_server, - }; - sessions.borrow_mut().insert(result.session_id, session); - - Ok(thread) - }) - } - - fn authenticate(&self, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) - } - - fn prompt( - &self, - params: agent_client_protocol::PromptArguments, - cx: &mut App, - ) -> Task> { - let client = self.client.client(); - let sessions = self.sessions.clone(); - - cx.foreground_executor().spawn(async move { - let client = client.context("MCP server is not initialized yet")?; - - let (new_cancel_tx, cancel_rx) = oneshot::channel(); - { - let mut sessions = sessions.borrow_mut(); - let session = sessions - .get_mut(¶ms.session_id) - .context("Session not found")?; - session.cancel_tx.replace(new_cancel_tx); - } - - let result = client - .request_with::( - context_server::types::CallToolParams { - name: acp::AGENT_METHODS.prompt.into(), - arguments: Some(serde_json::to_value(params)?), - meta: None, - }, - Some(cancel_rx), - None, - ) - .await; - - if let Err(err) = &result - && err.is::() - { - return Ok(()); - } - - let response = result?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - Ok(()) - }) - } - - fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { - let mut sessions = self.sessions.borrow_mut(); - - if let Some(cancel_tx) = sessions - .get_mut(session_id) - .and_then(|session| session.cancel_tx.take()) - { - cancel_tx.send(()).ok(); - } - } -} - -impl CodexConnection { - pub fn handle_session_notification( - notification: acp::SessionNotification, - threads: Rc>>, - cx: &mut AsyncApp, - ) { - let threads = threads.borrow(); - let Some(thread) = threads - .get(¬ification.session_id) - .and_then(|session| session.thread.upgrade()) - else { - log::error!( - "Thread not found for session ID: {}", - notification.session_id - ); - return; - }; - - thread - .update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - }) - .log_err(); - } -} - -impl Drop for CodexConnection { - fn drop(&mut self) { - self.client.stop().log_err(); - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 8b9fed5777..70c7f8efb5 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,25 +1,18 @@ -use anyhow::anyhow; -use std::cell::RefCell; -use std::path::Path; -use std::rc::Rc; -use util::ResultExt as _; - -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; -use agentic_coding_protocol as acp_old; -use anyhow::{Context as _, Result}; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; use settings::SettingsStore; -use ui::App; +use std::path::Path; +use std::rc::Rc; -use crate::AllAgentServersSettings; +use anyhow::Result; +use gpui::{App, Entity, Task}; + +use crate::acp_connection::AcpConnection; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; +use acp_thread::AgentConnection; #[derive(Clone)] pub struct Gemini; -const ACP_ARG: &str = "--experimental-acp"; - impl AgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" @@ -39,166 +32,49 @@ impl AgentServer for Gemini { fn connect( &self, - root_dir: &Path, + _root_dir: &Path, project: &Entity, cx: &mut App, ) -> Task>> { - let root_dir = root_dir.to_path_buf(); let project = project.clone(); - let this = self.clone(); - let name = self.name(); - + let server_name = self.name(); cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; + let Some(command) = AgentServerCommand::resolve( + "gemini", + &["--experimental-mcp"], + settings, + &project, + cx, + ) + .await + else { + anyhow::bail!("Failed to find gemini binary"); + }; - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - let connection: Rc = Rc::new(OldAcpAgentConnection { - name, - connection, - child_status, - }); - - Ok(connection) + let conn = AcpConnection::stdio(server_name, command, cx).await?; + Ok(Rc::new(conn) as _) }) } } -impl Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; - - if let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - { - return Ok(command); - }; - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if supported { - Ok(AgentServerVersion::Supported) - } else { - Ok(AgentServerVersion::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }) - } - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; use crate::AgentServerCommand; use std::path::Path; - crate::common_e2e_tests!(Gemini, allow_option_id = "0"); + crate::common_e2e_tests!(Gemini, allow_option_id = "allow"); pub fn local_command() -> AgentServerCommand { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../gemini-cli/packages/cli") - .to_string_lossy() - .to_string(); + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini/packages/cli"); AgentServerCommand { path: "node".into(), - args: vec![cli_path, ACP_ARG.into()], + args: vec![cli_path.to_string_lossy().to_string()], env: None, } } From b48faddaf416f2597bbec9bf30823017cd1d1931 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 28 Jul 2025 18:45:05 -0300 Subject: [PATCH 03/25] Restore gemini change --- crates/agent_servers/src/gemini.rs | 180 ++++++++++++++++++++++++----- 1 file changed, 152 insertions(+), 28 deletions(-) diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 70c7f8efb5..8b9fed5777 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,18 +1,25 @@ -use project::Project; -use settings::SettingsStore; +use anyhow::anyhow; +use std::cell::RefCell; use std::path::Path; use std::rc::Rc; +use util::ResultExt as _; -use anyhow::Result; -use gpui::{App, Entity, Task}; +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; +use agentic_coding_protocol as acp_old; +use anyhow::{Context as _, Result}; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use project::Project; +use settings::SettingsStore; +use ui::App; -use crate::acp_connection::AcpConnection; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::AgentConnection; +use crate::AllAgentServersSettings; #[derive(Clone)] pub struct Gemini; +const ACP_ARG: &str = "--experimental-acp"; + impl AgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" @@ -32,49 +39,166 @@ impl AgentServer for Gemini { fn connect( &self, - _root_dir: &Path, + root_dir: &Path, project: &Entity, cx: &mut App, ) -> Task>> { + let root_dir = root_dir.to_path_buf(); let project = project.clone(); - let server_name = self.name(); + let this = self.clone(); + let name = self.name(); + cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; + let command = this.command(&project, cx).await?; - let Some(command) = AgentServerCommand::resolve( - "gemini", - &["--experimental-mcp"], - settings, - &project, - cx, - ) - .await - else { - anyhow::bail!("Failed to find gemini binary"); - }; + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; - let conn = AcpConnection::stdio(server_name, command, cx).await?; - Ok(Rc::new(conn) as _) + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + let connection: Rc = Rc::new(OldAcpAgentConnection { + name, + connection, + child_status, + }); + + Ok(connection) }) } } +impl Gemini { + async fn command( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; + + if let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + { + return Ok(command); + }; + + let (fs, node_runtime) = project.update(cx, |project, _| { + (project.fs().clone(), project.node_runtime().cloned()) + })?; + let node_runtime = node_runtime.context("gemini not found on path")?; + + let directory = ::paths::agent_servers_dir().join("gemini"); + fs.create_dir(&directory).await?; + node_runtime + .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) + .await?; + let path = directory.join("node_modules/.bin/gemini"); + + Ok(AgentServerCommand { + path, + args: vec![ACP_ARG.into()], + env: None, + }) + } + + async fn version(&self, command: &AgentServerCommand) -> Result { + let version_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--version") + .kill_on_drop(true) + .output(); + + let help_fut = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .arg("--help") + .kill_on_drop(true) + .output(); + + let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; + + let current_version = String::from_utf8(version_output?.stdout)?; + let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); + + if supported { + Ok(AgentServerVersion::Supported) + } else { + Ok(AgentServerVersion::Unsupported { + error_message: format!( + "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", + current_version + ).into(), + upgrade_message: "Upgrade Gemini to Latest".into(), + upgrade_command: "npm install -g @google/gemini-cli@latest".into(), + }) + } + } +} + #[cfg(test)] pub(crate) mod tests { use super::*; use crate::AgentServerCommand; use std::path::Path; - crate::common_e2e_tests!(Gemini, allow_option_id = "allow"); + crate::common_e2e_tests!(Gemini, allow_option_id = "0"); pub fn local_command() -> AgentServerCommand { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini/packages/cli"); + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../../gemini-cli/packages/cli") + .to_string_lossy() + .to_string(); AgentServerCommand { path: "node".into(), - args: vec![cli_path.to_string_lossy().to_string()], + args: vec![cli_path, ACP_ARG.into()], env: None, } } From 912ab505b2677a90de74c73d6258cd4f709aee2d Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 28 Jul 2025 20:04:32 -0300 Subject: [PATCH 04/25] Connect to gemini over MCP --- crates/agent_servers/src/codex.rs | 1 + crates/agent_servers/src/gemini.rs | 159 ++++------------------------- 2 files changed, 19 insertions(+), 141 deletions(-) diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index 1909781efa..06d8d10a91 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -48,6 +48,7 @@ impl AgentServer for Codex { else { anyhow::bail!("Failed to find codex binary"); }; + // todo! check supported version let conn = AcpConnection::stdio(server_name, command, cx).await?; Ok(Rc::new(conn) as _) diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 8b9fed5777..07c4e1b539 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,14 +1,10 @@ -use anyhow::anyhow; -use std::cell::RefCell; use std::path::Path; use std::rc::Rc; -use util::ResultExt as _; -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; -use agentic_coding_protocol as acp_old; -use anyhow::{Context as _, Result}; -use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; +use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection}; +use acp_thread::AgentConnection; +use anyhow::Result; +use gpui::{Entity, Task}; use project::Project; use settings::SettingsStore; use ui::App; @@ -39,149 +35,30 @@ impl AgentServer for Gemini { fn connect( &self, - root_dir: &Path, + _root_dir: &Path, project: &Entity, cx: &mut App, ) -> Task>> { - let root_dir = root_dir.to_path_buf(); let project = project.clone(); - let this = self.clone(); - let name = self.name(); - + let server_name = self.name(); cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).gemini.clone() + })?; - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; + let Some(command) = + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + else { + anyhow::bail!("Failed to find gemini binary"); + }; + // todo! check supported version - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - let connection: Rc = Rc::new(OldAcpAgentConnection { - name, - connection, - child_status, - }); - - Ok(connection) + let conn = AcpConnection::stdio(server_name, command, cx).await?; + Ok(Rc::new(conn) as _) }) } } -impl Gemini { - async fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).gemini.clone() - })?; - - if let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await - { - return Ok(command); - }; - - let (fs, node_runtime) = project.update(cx, |project, _| { - (project.fs().clone(), project.node_runtime().cloned()) - })?; - let node_runtime = node_runtime.context("gemini not found on path")?; - - let directory = ::paths::agent_servers_dir().join("gemini"); - fs.create_dir(&directory).await?; - node_runtime - .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")]) - .await?; - let path = directory.join("node_modules/.bin/gemini"); - - Ok(AgentServerCommand { - path, - args: vec![ACP_ARG.into()], - env: None, - }) - } - - async fn version(&self, command: &AgentServerCommand) -> Result { - let version_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--version") - .kill_on_drop(true) - .output(); - - let help_fut = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .arg("--help") - .kill_on_drop(true) - .output(); - - let (version_output, help_output) = futures::future::join(version_fut, help_fut).await; - - let current_version = String::from_utf8(version_output?.stdout)?; - let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG); - - if supported { - Ok(AgentServerVersion::Supported) - } else { - Ok(AgentServerVersion::Unsupported { - error_message: format!( - "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).", - current_version - ).into(), - upgrade_message: "Upgrade Gemini to Latest".into(), - upgrade_command: "npm install -g @google/gemini-cli@latest".into(), - }) - } - } -} - #[cfg(test)] pub(crate) mod tests { use super::*; @@ -198,7 +75,7 @@ pub(crate) mod tests { AgentServerCommand { path: "node".into(), - args: vec![cli_path, ACP_ARG.into()], + args: vec![cli_path], env: None, } } From 254c6be42b69048bc5e1689e5438d1880fa8034c Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Tue, 29 Jul 2025 10:12:57 +0200 Subject: [PATCH 05/25] Fix broken test --- crates/acp_thread/src/acp_thread.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 1c4b0ec06f..841d320796 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1601,6 +1601,7 @@ mod tests { }; AcpThread::new( + "test", Rc::new(connection), project, acp::SessionId("test".into()), From 6656403ce85602e159262765654269297c95e630 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 29 Jul 2025 21:15:00 -0300 Subject: [PATCH 06/25] Auth WIP --- Cargo.lock | 2 - Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 4 - crates/acp_thread/src/connection.rs | 8 +- crates/acp_thread/src/old_acp_support.rs | 15 +++- crates/agent_servers/src/acp_connection.rs | 100 ++++++++++++++++----- crates/agent_servers/src/claude.rs | 10 ++- crates/agent_ui/src/acp/thread_view.rs | 55 ++++++++---- 8 files changed, 140 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1682b80a8c..f68136d978 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,8 +139,6 @@ dependencies = [ [[package]] name = "agent-client-protocol" version = "0.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142" dependencies = [ "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index 81da82cbb7..d733f2242e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.13" +agent-client-protocol = {path="../agent-client-protocol"} aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 841d320796..3b9f0842bd 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -958,10 +958,6 @@ impl AcpThread { cx.notify(); } - pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { - self.connection.authenticate(cx) - } - #[cfg(any(test, feature = "test-support"))] pub fn send_raw( &mut self, diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 97161a19c0..2e7deaf7df 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,6 +1,6 @@ -use std::{path::Path, rc::Rc}; +use std::{cell::Ref, path::Path, rc::Rc}; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; use project::Project; @@ -16,7 +16,9 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn authenticate(&self, cx: &mut App) -> Task>; + fn state(&self) -> Ref<'_, acp::AgentState>; + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index d7ef1b73da..4d06f81d06 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -5,7 +5,13 @@ use anyhow::{Context as _, Result}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; +use std::{ + cell::{Ref, RefCell}, + error::Error, + fmt, + path::Path, + rc::Rc, +}; use ui::App; use crate::{AcpThread, AgentConnection}; @@ -364,6 +370,7 @@ pub struct OldAcpAgentConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, pub child_status: Task>, + pub agent_state: Rc>, } impl AgentConnection for OldAcpAgentConnection { @@ -397,7 +404,11 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn authenticate(&self, cx: &mut App) -> Task> { + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { let task = self .connection .request_any(acp_old::AuthenticateParams.into_any()); diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 9139d62c38..96067fe520 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -7,10 +7,10 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::channel::{mpsc, oneshot}; use project::Project; use smol::stream::StreamExt as _; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::rc::Rc; use std::{path::Path, sync::Arc}; -use util::ResultExt; +use util::{ResultExt, TryFutureExt}; use anyhow::{Context, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; @@ -20,10 +20,12 @@ use crate::{AgentServerCommand, mcp_server}; use acp_thread::{AcpThread, AgentConnection}; pub struct AcpConnection { + agent_state: Rc>, server_name: &'static str, client: Arc, sessions: Rc>>, - _notification_handler_task: Task<()>, + _agent_state_task: Task<()>, + _session_update_task: Task<()>, } impl AcpConnection { @@ -43,29 +45,55 @@ impl AcpConnection { .into(); ContextServer::start(client.clone(), cx).await?; - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - client - .client() - .context("Failed to subscribe")? - .on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); + let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); + let mcp_client = client.client().context("Failed to subscribe")?; - if let Some(notification) = - serde_json::from_value::(notification).log_err() - { - notification_tx.unbounded_send(notification).ok(); - } + mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { + move |notification, _cx| { + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(state) = + serde_json::from_value::(notification).log_err() + { + state_tx.send(state).log_err(); } - }); + } + }); + + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + mcp_client.on_notification(acp::AGENT_METHODS.session_update, { + move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification).log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + } + }); let sessions = Rc::new(RefCell::new(HashMap::default())); + let initial_state = state_rx.recv().await?; + let agent_state = Rc::new(RefCell::new(initial_state)); - let notification_handler_task = cx.spawn({ + let agent_state_task = cx.foreground_executor().spawn({ + let agent_state = agent_state.clone(); + async move { + while let Some(state) = state_rx.recv().log_err().await { + agent_state.replace(state); + } + } + }); + + let session_update_handler_task = cx.spawn({ let sessions = sessions.clone(); async move |cx| { while let Some(notification) = notification_rx.next().await { @@ -78,7 +106,9 @@ impl AcpConnection { server_name, client, sessions, - _notification_handler_task: notification_handler_task, + agent_state, + _agent_state_task: agent_state_task, + _session_update_task: session_update_handler_task, }) } @@ -185,8 +215,30 @@ impl AgentConnection for AcpConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { - Task::ready(Err(anyhow!("Authentication not supported"))) + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { + let client = self.client.client(); + cx.foreground_executor().spawn(async move { + let params = acp::AuthenticateArguments { method_id }; + + let response = client + .context("MCP server is not initialized yet")? + .request::(context_server::types::CallToolParams { + name: acp::AGENT_METHODS.authenticate.into(), + arguments: Some(serde_json::to_value(params)?), + meta: None, + }) + .await?; + + if response.is_error.unwrap_or_default() { + Err(anyhow!(response.text_contents())) + } else { + Ok(()) + } + }) } fn prompt( diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 590da69cd8..0f49403a0b 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::RefCell; +use std::cell::{Ref, RefCell}; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -58,6 +58,7 @@ impl AgentServer for ClaudeCode { _cx: &mut App, ) -> Task>> { let connection = ClaudeAgentConnection { + agent_state: Default::default(), sessions: Default::default(), }; @@ -66,6 +67,7 @@ impl AgentServer for ClaudeCode { } struct ClaudeAgentConnection { + agent_state: Rc>, sessions: Rc>>, } @@ -183,7 +185,11 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn state(&self) -> Ref<'_, acp::AgentState> { + self.agent_state.borrow() + } + + fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Err(anyhow!("Authentication not supported"))) } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e46e1ae3ab..824748a0aa 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -216,6 +216,15 @@ impl AcpThreadView { } }; + if connection.state().needs_authentication { + this.update(cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { connection }; + cx.notify(); + }) + .ok(); + return; + } + let result = match connection .clone() .new_thread(project.clone(), &root_dir, cx) @@ -223,6 +232,7 @@ impl AcpThreadView { { Err(e) => { let mut cx = cx.clone(); + // todo! remove duplication if e.downcast_ref::().is_some() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; @@ -640,13 +650,18 @@ impl AcpThreadView { Some(entry.diffs().map(|diff| diff.multibuffer.clone())) } - fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + fn authenticate( + &mut self, + method: acp::AuthMethodId, + window: &mut Window, + cx: &mut Context, + ) { let ThreadState::Unauthenticated { ref connection } = self.thread_state else { return; }; self.last_error.take(); - let authenticate = connection.authenticate(cx); + let authenticate = connection.authenticate(method, cx); self.auth_task = Some(cx.spawn_in(window, { let project = self.project.clone(); let agent = self.agent.clone(); @@ -2197,22 +2212,26 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::next_history_message)) .on_action(cx.listener(Self::open_agent_diff)) .child(match &self.thread_state { - ThreadState::Unauthenticated { .. } => { - v_flex() - .p_2() - .flex_1() - .items_center() - .justify_center() - .child(self.render_pending_auth_state()) - .child( - h_flex().mt_1p5().justify_center().child( - Button::new("sign-in", format!("Sign in to {}", self.agent.name())) - .on_click(cx.listener(|this, _, window, cx| { - this.authenticate(window, cx) - })), - ), - ) - } + ThreadState::Unauthenticated { connection } => v_flex() + .p_2() + .flex_1() + .items_center() + .justify_center() + .child(self.render_pending_auth_state()) + .child(h_flex().mt_1p5().justify_center().children( + connection.state().auth_methods.iter().map(|method| { + Button::new( + SharedString::from(method.id.0.clone()), + method.label.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) + }) + }), + )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), ThreadState::LoadError(e) => v_flex() .p_2() From 81c111510f43241cef933567fc2905051dfc5fa3 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 30 Jul 2025 15:48:40 +0200 Subject: [PATCH 07/25] Refactor handling of ContextServer notifications The notification handler registration is now more explicit, with handlers set up before server initialization to avoid potential race conditions. --- crates/agent_servers/src/acp_connection.rs | 85 +++++++++++---------- crates/context_server/src/client.rs | 14 ++-- crates/context_server/src/context_server.rs | 27 ++++++- crates/context_server/src/protocol.rs | 9 ++- 4 files changed, 79 insertions(+), 56 deletions(-) diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 95c09e2c52..5883f6ac45 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection}; pub struct AcpConnection { agent_state: Rc>, server_name: &'static str, - client: Arc, + context_server: Arc, sessions: Rc>>, _agent_state_task: Task<()>, _session_update_task: Task<()>, @@ -35,7 +35,7 @@ impl AcpConnection { working_directory: Option>, cx: &mut AsyncApp, ) -> Result { - let client: Arc = ContextServer::stdio( + let context_server: Arc = ContextServer::stdio( ContextServerId(format!("{}-mcp-server", server_name).into()), ContextServerCommand { path: command.path, @@ -45,42 +45,9 @@ impl AcpConnection { working_directory, ) .into(); - ContextServer::start(client.clone(), cx).await?; let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); - let mcp_client = client.client().context("Failed to subscribe")?; - - mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { - move |notification, _cx| { - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(state) = - serde_json::from_value::(notification).log_err() - { - state_tx.send(state).log_err(); - } - } - }); - let (notification_tx, mut notification_rx) = mpsc::unbounded(); - mcp_client.on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification).log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); let sessions = Rc::new(RefCell::new(HashMap::default())); let initial_state = state_rx.recv().await?; @@ -104,9 +71,47 @@ impl AcpConnection { } }); + context_server + .start_with_handlers( + vec![ + (acp::AGENT_METHODS.agent_state, { + Box::new(move |notification, _cx| { + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(state) = + serde_json::from_value::(notification).log_err() + { + state_tx.send(state).log_err(); + } + }) + }), + (acp::AGENT_METHODS.session_update, { + Box::new(move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification) + .log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + }) + }), + ], + cx, + ) + .await?; + Ok(Self { server_name, - client, + context_server, sessions, agent_state, _agent_state_task: agent_state_task, @@ -152,7 +157,7 @@ impl AgentConnection for AcpConnection { cwd: &Path, cx: &mut AsyncApp, ) -> Task>> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); let cwd = cwd.to_path_buf(); cx.spawn(async move |cx| { @@ -222,7 +227,7 @@ impl AgentConnection for AcpConnection { } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); cx.foreground_executor().spawn(async move { let params = acp::AuthenticateArguments { method_id }; @@ -248,7 +253,7 @@ impl AgentConnection for AcpConnection { params: agent_client_protocol::PromptArguments, cx: &mut App, ) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); cx.foreground_executor().spawn(async move { @@ -305,6 +310,6 @@ impl AgentConnection for AcpConnection { impl Drop for AcpConnection { fn drop(&mut self) { - self.client.stop().log_err(); + self.context_server.stop().log_err(); } } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 1eb29bbbf9..65283afa87 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -441,14 +441,12 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index e76e7972f7..34fa29678d 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -95,8 +95,28 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { @@ -113,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 9ccbc8a553..5355f20f62 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -115,10 +115,11 @@ impl InitializedContextServerProtocol { self.inner.notify(T::METHOD, params) } - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { self.inner.on_notification(method, f); } } From 738296345eaa880ebda8adf247e9974c7a20c880 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 30 Jul 2025 11:46:11 -0300 Subject: [PATCH 08/25] Inline tool schemas --- crates/context_server/src/listener.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 34e3a9a78c..0e85fb2129 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -83,14 +83,18 @@ impl McpServer { } pub fn add_tool(&mut self, tool: T) { - let output_schema = schemars::schema_for!(T::Output); - let unit_schema = schemars::schema_for!(()); + let mut settings = schemars::generate::SchemaSettings::draft07(); + settings.inline_subschemas = true; + let mut generator = settings.into_generator(); + + let output_schema = generator.root_schema_for::(); + let unit_schema = generator.root_schema_for::(); let registered_tool = RegisteredTool { tool: Tool { name: T::NAME.into(), description: Some(tool.description().into()), - input_schema: schemars::schema_for!(T::Input).into(), + input_schema: generator.root_schema_for::().into(), output_schema: if output_schema == unit_schema { None } else { From 27708143ecac70b963bbe70153d426fe5f061277 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 30 Jul 2025 13:30:50 -0300 Subject: [PATCH 09/25] Fix auth --- crates/acp_thread/src/acp_thread.rs | 1 - crates/acp_thread/src/connection.rs | 14 +++++- crates/acp_thread/src/old_acp_support.rs | 31 ++++--------- crates/agent_servers/src/acp_connection.rs | 54 +++++++--------------- crates/agent_servers/src/claude.rs | 8 ++-- crates/agent_ui/src/acp/thread_view.rs | 28 ++++------- 6 files changed, 48 insertions(+), 88 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 8aa07da330..bc2f8c756a 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1595,7 +1595,6 @@ mod tests { connection, child_status: io_task, current_thread: thread_rc, - agent_state: Default::default(), }; AcpThread::new( diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 2e7deaf7df..11f1fcc94c 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,4 +1,4 @@ -use std::{cell::Ref, path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc}; use agent_client_protocol::{self as acp}; use anyhow::Result; @@ -16,7 +16,7 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn state(&self) -> Ref<'_, acp::AgentState>; + fn auth_methods(&self) -> Vec; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; @@ -24,3 +24,13 @@ pub trait AgentConnection { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } + +#[derive(Debug)] +pub struct AuthRequired; + +impl Error for AuthRequired {} +impl fmt::Display for AuthRequired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthRequired") + } +} diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 718ad0da03..88313e0fd5 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -5,17 +5,11 @@ use anyhow::{Context as _, Result}; use futures::channel::oneshot; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; -use std::{ - cell::{Ref, RefCell}, - error::Error, - fmt, - path::Path, - rc::Rc, -}; +use std::{cell::RefCell, path::Path, rc::Rc}; use ui::App; use util::ResultExt as _; -use crate::{AcpThread, AgentConnection}; +use crate::{AcpThread, AgentConnection, AuthRequired}; #[derive(Clone)] pub struct OldAcpClientDelegate { @@ -357,21 +351,10 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } } -#[derive(Debug)] -pub struct Unauthenticated; - -impl Error for Unauthenticated {} -impl fmt::Display for Unauthenticated { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Unauthenticated") - } -} - pub struct OldAcpAgentConnection { pub name: &'static str, pub connection: acp_old::AgentConnection, pub child_status: Task>, - pub agent_state: Rc>, pub current_thread: Rc>>, } @@ -394,7 +377,7 @@ impl AgentConnection for OldAcpAgentConnection { let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(Unauthenticated) + anyhow::bail!(AuthRequired) } cx.update(|cx| { @@ -408,8 +391,12 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + vec![acp::AuthMethod { + id: acp::AuthMethodId("acp-old-no-id".into()), + label: "Log in".into(), + description: None, + }] } fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index 95c09e2c52..c19a145196 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -7,24 +7,23 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::channel::{mpsc, oneshot}; use project::Project; use smol::stream::StreamExt as _; -use std::cell::{Ref, RefCell}; +use std::cell::RefCell; use std::rc::Rc; use std::{path::Path, sync::Arc}; -use util::{ResultExt, TryFutureExt}; +use util::ResultExt; use anyhow::{Context, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use crate::mcp_server::ZedMcpServer; use crate::{AgentServerCommand, mcp_server}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection, AuthRequired}; pub struct AcpConnection { - agent_state: Rc>, + auth_methods: Rc>>, server_name: &'static str, client: Arc, sessions: Rc>>, - _agent_state_task: Task<()>, _session_update_task: Task<()>, } @@ -47,24 +46,8 @@ impl AcpConnection { .into(); ContextServer::start(client.clone(), cx).await?; - let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default()); let mcp_client = client.client().context("Failed to subscribe")?; - mcp_client.on_notification(acp::AGENT_METHODS.agent_state, { - move |notification, _cx| { - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(state) = - serde_json::from_value::(notification).log_err() - { - state_tx.send(state).log_err(); - } - } - }); - let (notification_tx, mut notification_rx) = mpsc::unbounded(); mcp_client.on_notification(acp::AGENT_METHODS.session_update, { move |notification, _cx| { @@ -83,17 +66,6 @@ impl AcpConnection { }); let sessions = Rc::new(RefCell::new(HashMap::default())); - let initial_state = state_rx.recv().await?; - let agent_state = Rc::new(RefCell::new(initial_state)); - - let agent_state_task = cx.foreground_executor().spawn({ - let agent_state = agent_state.clone(); - async move { - while let Some(state) = state_rx.recv().log_err().await { - agent_state.replace(state); - } - } - }); let session_update_handler_task = cx.spawn({ let sessions = sessions.clone(); @@ -105,11 +77,10 @@ impl AcpConnection { }); Ok(Self { + auth_methods: Default::default(), server_name, client, sessions, - agent_state, - _agent_state_task: agent_state_task, _session_update_task: session_update_handler_task, }) } @@ -154,6 +125,7 @@ impl AgentConnection for AcpConnection { ) -> Task>> { let client = self.client.client(); let sessions = self.sessions.clone(); + let auth_methods = self.auth_methods.clone(); let cwd = cwd.to_path_buf(); cx.spawn(async move |cx| { let client = client.context("MCP server is not initialized yet")?; @@ -194,12 +166,18 @@ impl AgentConnection for AcpConnection { response.structured_content.context("Empty response")?, )?; + auth_methods.replace(result.auth_methods); + + let Some(session_id) = result.session_id else { + anyhow::bail!(AuthRequired); + }; + let thread = cx.new(|cx| { AcpThread::new( self.server_name, self.clone(), project, - result.session_id.clone(), + session_id.clone(), cx, ) })?; @@ -211,14 +189,14 @@ impl AgentConnection for AcpConnection { cancel_tx: None, _mcp_server: mcp_server, }; - sessions.borrow_mut().insert(result.session_id, session); + sessions.borrow_mut().insert(session_id, session); Ok(thread) }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + self.auth_methods.borrow().clone() } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 0f49403a0b..736fdd2726 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::{Ref, RefCell}; +use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -58,7 +58,6 @@ impl AgentServer for ClaudeCode { _cx: &mut App, ) -> Task>> { let connection = ClaudeAgentConnection { - agent_state: Default::default(), sessions: Default::default(), }; @@ -67,7 +66,6 @@ impl AgentServer for ClaudeCode { } struct ClaudeAgentConnection { - agent_state: Rc>, sessions: Rc>>, } @@ -185,8 +183,8 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn state(&self) -> Ref<'_, acp::AgentState> { - self.agent_state.borrow() + fn auth_methods(&self) -> Vec { + vec![] } fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 824748a0aa..6d7684bbfc 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -216,15 +216,6 @@ impl AcpThreadView { } }; - if connection.state().needs_authentication { - this.update(cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); - }) - .ok(); - return; - } - let result = match connection .clone() .new_thread(project.clone(), &root_dir, cx) @@ -233,7 +224,7 @@ impl AcpThreadView { Err(e) => { let mut cx = cx.clone(); // todo! remove duplication - if e.downcast_ref::().is_some() { + if e.downcast_ref::().is_some() { this.update(&mut cx, |this, cx| { this.thread_state = ThreadState::Unauthenticated { connection }; cx.notify(); @@ -2219,17 +2210,14 @@ impl Render for AcpThreadView { .justify_center() .child(self.render_pending_auth_state()) .child(h_flex().mt_1p5().justify_center().children( - connection.state().auth_methods.iter().map(|method| { - Button::new( - SharedString::from(method.id.0.clone()), - method.label.clone(), - ) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) + connection.auth_methods().into_iter().map(|method| { + Button::new(SharedString::from(method.id.0.clone()), method.label) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) + }) }) - }) }), )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), From 02d3043ec560a7e7bdc1df47efc70c4a2e4fb378 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 30 Jul 2025 14:28:01 -0300 Subject: [PATCH 10/25] Rename arg to experimental-mcp --- crates/agent_servers/src/gemini.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 9b7fde42bf..77e7d1063f 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -14,7 +14,7 @@ use crate::AllAgentServersSettings; #[derive(Clone)] pub struct Gemini; -const ACP_ARG: &str = "--experimental-acp"; +const MCP_ARG: &str = "--experimental-mcp"; impl AgentServer for Gemini { fn name(&self) -> &'static str { @@ -48,7 +48,7 @@ impl AgentServer for Gemini { })?; let Some(command) = - AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await + AgentServerCommand::resolve("gemini", &[MCP_ARG], settings, &project, cx).await else { anyhow::bail!("Failed to find gemini binary"); }; From 8563ed22521b98224bdc7b8d1c295a490b5f532f Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 1 Aug 2025 17:07:50 -0300 Subject: [PATCH 11/25] Compiling --- Cargo.lock | 16 +- crates/acp_thread/src/acp_thread.rs | 7 +- crates/acp_thread/src/connection.rs | 4 +- crates/acp_thread/src/old_acp_support.rs | 11 +- crates/agent_servers/src/acp_connection.rs | 348 +++++++++------------ crates/agent_servers/src/agent_servers.rs | 4 - crates/agent_servers/src/claude.rs | 6 +- crates/agent_servers/src/codex.rs | 78 ----- crates/agent_servers/src/e2e_tests.rs | 3 - crates/agent_servers/src/gemini.rs | 10 +- crates/agent_servers/src/mcp_server.rs | 208 ------------ crates/agent_servers/src/settings.rs | 11 +- crates/agent_ui/src/acp/thread_view.rs | 15 +- crates/agent_ui/src/agent_panel.rs | 33 -- crates/agent_ui/src/agent_ui.rs | 2 - 15 files changed, 188 insertions(+), 568 deletions(-) delete mode 100644 crates/agent_servers/src/codex.rs delete mode 100644 crates/agent_servers/src/mcp_server.rs diff --git a/Cargo.lock b/Cargo.lock index f4c328c957..f31ecdef99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,6 +140,10 @@ dependencies = [ name = "agent-client-protocol" version = "0.0.13" dependencies = [ + "anyhow", + "futures 0.3.31", + "log", + "parking_lot", "schemars", "serde", "serde_json", @@ -9624,9 +9628,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -11315,9 +11319,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -11325,9 +11329,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index bc2f8c756a..3bf6134862 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -999,7 +999,7 @@ impl AcpThread { let result = this .update(cx, |this, cx| { this.connection.prompt( - acp::PromptArguments { + acp::PromptRequest { prompt: message, session_id: this.session_id.clone(), }, @@ -1595,6 +1595,11 @@ mod tests { connection, child_status: io_task, current_thread: thread_rc, + auth_methods: [acp::AuthMethod { + id: acp::AuthMethodId("acp-old-no-id".into()), + label: "Log in".into(), + description: None, + }], }; AcpThread::new( diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 11f1fcc94c..929500a67b 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -16,11 +16,11 @@ pub trait AgentConnection { cx: &mut AsyncApp, ) -> Task>>; - fn auth_methods(&self) -> Vec; + fn auth_methods(&self) -> &[acp::AuthMethod]; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 88313e0fd5..adb27e21c4 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -356,6 +356,7 @@ pub struct OldAcpAgentConnection { pub connection: acp_old::AgentConnection, pub child_status: Task>, pub current_thread: Rc>>, + pub auth_methods: [acp::AuthMethod; 1], } impl AgentConnection for OldAcpAgentConnection { @@ -391,12 +392,8 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn auth_methods(&self) -> Vec { - vec![acp::AuthMethod { - id: acp::AuthMethodId("acp-old-no-id".into()), - label: "Log in".into(), - description: None, - }] + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods } fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task> { @@ -409,7 +406,7 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let chunks = params .prompt .into_iter() diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index bfb4d8b40f..ca9ec2aea0 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -1,123 +1,87 @@ use agent_client_protocol as acp; -use anyhow::anyhow; use collections::HashMap; -use context_server::listener::McpServerTool; -use context_server::types::requests; -use context_server::{ContextServer, ContextServerCommand, ContextServerId}; -use futures::channel::{mpsc, oneshot}; +use futures::channel::oneshot; use project::Project; -use smol::stream::StreamExt as _; use std::cell::RefCell; +use std::path::Path; use std::rc::Rc; -use std::{path::Path, sync::Arc}; use util::ResultExt; -use anyhow::{Context, Result}; +use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; -use crate::mcp_server::ZedMcpServer; -use crate::{AgentServerCommand, mcp_server}; +use crate::AgentServerCommand; use acp_thread::{AcpThread, AgentConnection, AuthRequired}; pub struct AcpConnection { - auth_methods: Rc>>, server_name: &'static str, - context_server: Arc, + connection: Rc, sessions: Rc>>, - _session_update_task: Task<()>, + auth_methods: Vec, + _io_task: Task>, +} + +pub struct AcpSession { + thread: WeakEntity, } impl AcpConnection { pub async fn stdio( server_name: &'static str, command: AgentServerCommand, - working_directory: Option>, + root_dir: &Path, cx: &mut AsyncApp, ) -> Result { - let context_server: Arc = ContextServer::stdio( - ContextServerId(format!("{}-mcp-server", server_name).into()), - ContextServerCommand { - path: command.path, - args: command.args, - env: command.env, - }, - working_directory, - ) - .into(); + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter().map(|arg| arg.as_str())) + .envs(command.env.iter().flatten()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; - let (notification_tx, mut notification_rx) = mpsc::unbounded(); + let stdout = child.stdout.take().expect("Failed to take stdout"); + let stdin = child.stdin.take().expect("Failed to take stdin"); let sessions = Rc::new(RefCell::new(HashMap::default())); - let session_update_handler_task = cx.spawn({ - let sessions = sessions.clone(); - async move |cx| { - while let Some(notification) = notification_rx.next().await { - Self::handle_session_notification(notification, sessions.clone(), cx) - } + let client = ClientDelegate { + sessions: sessions.clone(), + cx: cx.clone(), + }; + let (connection, io_task) = acp::AgentConnection::new(client, stdin, stdout, { + let foreground_executor = cx.foreground_executor().clone(); + move |fut| { + foreground_executor.spawn(fut).detach(); } }); - context_server - .start_with_handlers( - vec![(acp::AGENT_METHODS.session_update, { - Box::new(move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); + let io_task = cx.background_spawn(io_task); - if let Some(notification) = - serde_json::from_value::(notification) - .log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - }) - })], - cx, - ) + let response = connection + .initialize(acp::InitializeRequest { + protocol_version: acp::VERSION, + client_capabilities: acp::ClientCapabilities { + fs: acp::FileSystemCapability { + read_text_file: true, + write_text_file: true, + }, + }, + }) .await?; + // todo! check version + Ok(Self { - auth_methods: Default::default(), + auth_methods: response.auth_methods, + connection: connection.into(), server_name, - context_server, sessions, - _session_update_task: session_update_handler_task, + _io_task: io_task, }) } - - pub fn handle_session_notification( - notification: acp::SessionNotification, - threads: Rc>>, - cx: &mut AsyncApp, - ) { - let threads = threads.borrow(); - let Some(thread) = threads - .get(¬ification.session_id) - .and_then(|session| session.thread.upgrade()) - else { - log::error!( - "Thread not found for session ID: {}", - notification.session_id - ); - return; - }; - - thread - .update(cx, |thread, cx| { - thread.handle_session_update(notification.update, cx) - }) - .log_err(); - } -} - -pub struct AcpSession { - thread: WeakEntity, - cancel_tx: Option>, - _mcp_server: ZedMcpServer, } impl AgentConnection for AcpConnection { @@ -127,52 +91,19 @@ impl AgentConnection for AcpConnection { cwd: &Path, cx: &mut AsyncApp, ) -> Task>> { - let client = self.context_server.client(); + let conn = self.connection.clone(); let sessions = self.sessions.clone(); - let auth_methods = self.auth_methods.clone(); let cwd = cwd.to_path_buf(); cx.spawn(async move |cx| { - let client = client.context("MCP server is not initialized yet")?; - let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); - - let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; - - let response = client - .request::(context_server::types::CallToolParams { - name: acp::AGENT_METHODS.new_session.into(), - arguments: Some(serde_json::to_value(acp::NewSessionArguments { - mcp_servers: vec![mcp_server.server_config()?], - client_tools: acp::ClientTools { - request_permission: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::RequestPermissionTool::NAME.into(), - }), - read_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::ReadTextFileTool::NAME.into(), - }), - write_text_file: Some(acp::McpToolId { - mcp_server: mcp_server::SERVER_NAME.into(), - tool_name: mcp_server::WriteTextFileTool::NAME.into(), - }), - }, - cwd, - })?), - meta: None, + let response = conn + .new_session(acp::NewSessionRequest { + // todo! Zed MCP server? + mcp_servers: vec![], + cwd, }) .await?; - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - let result = serde_json::from_value::( - response.structured_content.context("Empty response")?, - )?; - - auth_methods.replace(result.auth_methods); - - let Some(session_id) = result.session_id else { + let Some(session_id) = response.session_id else { anyhow::bail!(AuthRequired); }; @@ -186,12 +117,8 @@ impl AgentConnection for AcpConnection { ) })?; - thread_tx.send(thread.downgrade())?; - let session = AcpSession { thread: thread.downgrade(), - cancel_tx: None, - _mcp_server: mcp_server, }; sessions.borrow_mut().insert(session_id, session); @@ -199,94 +126,115 @@ impl AgentConnection for AcpConnection { }) } - fn auth_methods(&self) -> Vec { - self.auth_methods.borrow().clone() + fn auth_methods(&self) -> &[acp::AuthMethod] { + &self.auth_methods } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let client = self.context_server.client(); + let conn = self.connection.clone(); cx.foreground_executor().spawn(async move { - let params = acp::AuthenticateArguments { method_id }; - - let response = client - .context("MCP server is not initialized yet")? - .request::(context_server::types::CallToolParams { - name: acp::AGENT_METHODS.authenticate.into(), - arguments: Some(serde_json::to_value(params)?), - meta: None, + let result = conn + .authenticate(acp::AuthenticateRequest { + method_id: method_id.clone(), }) .await?; - if response.is_error.unwrap_or_default() { - Err(anyhow!(response.text_contents())) - } else { - Ok(()) - } + Ok(result) }) } - fn prompt( + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + cx.foreground_executor() + .spawn(async move { Ok(conn.prompt(params).await?) }) + } + + fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { + self.connection.cancel(session_id.clone()).log_err(); + } +} + +struct ClientDelegate { + sessions: Rc>>, + cx: AsyncApp, +} + +impl acp::Client for ClientDelegate { + async fn request_permission( &self, - params: agent_client_protocol::PromptArguments, - cx: &mut App, - ) -> Task> { - let client = self.context_server.client(); - let sessions = self.sessions.clone(); + arguments: acp::RequestPermissionRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let result = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + })? + .await; - cx.foreground_executor().spawn(async move { - let client = client.context("MCP server is not initialized yet")?; + let outcome = match result { + Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled, + }; - let (new_cancel_tx, cancel_rx) = oneshot::channel(); - { - let mut sessions = sessions.borrow_mut(); - let session = sessions - .get_mut(¶ms.session_id) - .context("Session not found")?; - session.cancel_tx.replace(new_cancel_tx); - } - - let result = client - .request_with::( - context_server::types::CallToolParams { - name: acp::AGENT_METHODS.prompt.into(), - arguments: Some(serde_json::to_value(params)?), - meta: None, - }, - Some(cancel_rx), - None, - ) - .await; - - if let Err(err) = &result - && err.is::() - { - return Ok(()); - } - - let response = result?; - - if response.is_error.unwrap_or_default() { - return Err(anyhow!(response.text_contents())); - } - - Ok(()) - }) + Ok(acp::RequestPermissionResponse { outcome }) } - fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { - let mut sessions = self.sessions.borrow_mut(); + async fn write_text_file( + &self, + arguments: acp::WriteTextFileRequest, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + self.sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.write_text_file(arguments.path, arguments.content, cx) + })? + .await?; - if let Some(cancel_tx) = sessions - .get_mut(session_id) - .and_then(|session| session.cancel_tx.take()) - { - cancel_tx.send(()).ok(); - } - } -} - -impl Drop for AcpConnection { - fn drop(&mut self) { - self.context_server.stop().log_err(); + Ok(()) + } + + async fn read_text_file( + &self, + arguments: acp::ReadTextFileRequest, + ) -> Result { + let cx = &mut self.cx.clone(); + let content = self + .sessions + .borrow() + .get(&arguments.session_id) + .context("Failed to get session")? + .thread + .update(cx, |thread, cx| { + thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx) + })? + .await?; + + Ok(acp::ReadTextFileResponse { content }) + } + + async fn session_notification( + &self, + notification: acp::SessionNotification, + ) -> Result<(), acp::Error> { + let cx = &mut self.cx.clone(); + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + + session.thread.update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + })??; + + Ok(()) } } diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 6a031a190e..13bad53cd9 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,15 +1,12 @@ mod acp_connection; mod claude; -mod codex; mod gemini; -mod mcp_server; mod settings; #[cfg(test)] mod e2e_tests; pub use claude::*; -pub use codex::*; pub use gemini::*; pub use settings::*; @@ -39,7 +36,6 @@ pub trait AgentServer: Send { fn connect( &self, - // these will go away when old_acp is fully removed root_dir: &Path, project: &Entity, cx: &mut App, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 736fdd2726..9040b83085 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -183,15 +183,15 @@ impl AgentConnection for ClaudeAgentConnection { }) } - fn auth_methods(&self) -> Vec { - vec![] + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] } fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task> { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let sessions = self.sessions.borrow(); let Some(session) = sessions.get(¶ms.session_id) else { return Task::ready(Err(anyhow!( diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs deleted file mode 100644 index 3e774ed83e..0000000000 --- a/crates/agent_servers/src/codex.rs +++ /dev/null @@ -1,78 +0,0 @@ -use project::Project; -use settings::SettingsStore; -use std::path::Path; -use std::rc::Rc; - -use anyhow::Result; -use gpui::{App, Entity, Task}; - -use crate::acp_connection::AcpConnection; -use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::AgentConnection; - -#[derive(Clone)] -pub struct Codex; - -impl AgentServer for Codex { - fn name(&self) -> &'static str { - "Codex" - } - - fn empty_state_headline(&self) -> &'static str { - "Welcome to Codex" - } - - fn empty_state_message(&self) -> &'static str { - "What can I help with?" - } - - fn logo(&self) -> ui::IconName { - ui::IconName::AiOpenAi - } - - fn connect( - &self, - _root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let project = project.clone(); - let server_name = self.name(); - let working_directory = project.read(cx).active_project_directory(cx); - cx.spawn(async move |cx| { - let settings = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).codex.clone() - })?; - - let Some(command) = - AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await - else { - anyhow::bail!("Failed to find codex binary"); - }; - // todo! check supported version - - let conn = AcpConnection::stdio(server_name, command, working_directory, cx).await?; - Ok(Rc::new(conn) as _) - }) - } -} - -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use crate::AgentServerCommand; - use std::path::Path; - - crate::common_e2e_tests!(Codex, allow_option_id = "approve"); - - pub fn local_command() -> AgentServerCommand { - let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) - .join("../../../codex/codex-rs/target/debug/codex"); - - AgentServerCommand { - path: cli_path, - args: vec![], - env: None, - } - } -} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index e9c72eabc9..16bf1e6b47 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { gemini: Some(AgentServerSettings { command: crate::gemini::tests::local_command(), }), - codex: Some(AgentServerSettings { - command: crate::codex::tests::local_command(), - }), }, cx, ); diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 77e7d1063f..372ce76aa9 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -14,7 +14,7 @@ use crate::AllAgentServersSettings; #[derive(Clone)] pub struct Gemini; -const MCP_ARG: &str = "--experimental-mcp"; +const ACP_ARG: &str = "--experimental-acp"; impl AgentServer for Gemini { fn name(&self) -> &'static str { @@ -35,26 +35,26 @@ impl AgentServer for Gemini { fn connect( &self, - _root_dir: &Path, + root_dir: &Path, project: &Entity, cx: &mut App, ) -> Task>> { let project = project.clone(); let server_name = self.name(); - let working_directory = project.read(cx).active_project_directory(cx); + let root_dir = root_dir.to_path_buf(); cx.spawn(async move |cx| { let settings = cx.read_global(|settings: &SettingsStore, _| { settings.get::(None).gemini.clone() })?; let Some(command) = - AgentServerCommand::resolve("gemini", &[MCP_ARG], settings, &project, cx).await + AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await else { anyhow::bail!("Failed to find gemini binary"); }; // todo! check supported version - let conn = AcpConnection::stdio(server_name, command, working_directory, cx).await?; + let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?; Ok(Rc::new(conn) as _) }) } diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs deleted file mode 100644 index ec655800ed..0000000000 --- a/crates/agent_servers/src/mcp_server.rs +++ /dev/null @@ -1,208 +0,0 @@ -use acp_thread::AcpThread; -use agent_client_protocol as acp; -use anyhow::Result; -use context_server::listener::{McpServerTool, ToolResponse}; -use context_server::types::{ - Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolsCapabilities, requests, -}; -use futures::channel::oneshot; -use gpui::{App, AsyncApp, Task, WeakEntity}; -use indoc::indoc; - -pub struct ZedMcpServer { - server: context_server::listener::McpServer, -} - -pub const SERVER_NAME: &str = "zed"; - -impl ZedMcpServer { - pub async fn new( - thread_rx: watch::Receiver>, - cx: &AsyncApp, - ) -> Result { - let mut mcp_server = context_server::listener::McpServer::new(cx).await?; - mcp_server.handle_request::(Self::handle_initialize); - - mcp_server.add_tool(RequestPermissionTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(ReadTextFileTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(WriteTextFileTool { - thread_rx: thread_rx.clone(), - }); - - Ok(Self { server: mcp_server }) - } - - pub fn server_config(&self) -> Result { - #[cfg(not(test))] - let zed_path = anyhow::Context::context( - std::env::current_exe(), - "finding current executable path for use in mcp_server", - )?; - - #[cfg(test)] - let zed_path = crate::e2e_tests::get_zed_path(); - - Ok(acp::McpServer { - name: SERVER_NAME.into(), - command: zed_path, - args: vec![ - "--nc".into(), - self.server.socket_path().display().to_string(), - ], - env: vec![], - }) - } - - fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(InitializeResponse { - protocol_version: ProtocolVersion("2025-06-18".into()), - capabilities: ServerCapabilities { - experimental: None, - logging: None, - completions: None, - prompts: None, - resources: None, - tools: Some(ToolsCapabilities { - list_changed: Some(false), - }), - }, - server_info: Implementation { - name: SERVER_NAME.into(), - version: "0.1.0".into(), - }, - meta: None, - }) - }) - } -} - -// Tools - -#[derive(Clone)] -pub struct RequestPermissionTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for RequestPermissionTool { - type Input = acp::RequestPermissionArguments; - type Output = acp::RequestPermissionOutput; - - const NAME: &'static str = "Confirmation"; - - fn description(&self) -> &'static str { - indoc! {" - Request permission for tool calls. - - This tool is meant to be called programmatically by the agent loop, not the LLM. - "} - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let result = thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission(input.tool_call, input.options, cx) - })? - .await; - - let outcome = match result { - Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id }, - Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled, - }; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::RequestPermissionOutput { outcome }, - }) - } -} - -#[derive(Clone)] -pub struct ReadTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTextFileTool { - type Input = acp::ReadTextFileArguments; - type Output = acp::ReadTextFileOutput; - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Reads the content of the given file in the project including unsaved changes." - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.path, input.line, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: acp::ReadTextFileOutput { content }, - }) - } -} - -#[derive(Clone)] -pub struct WriteTextFileTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for WriteTextFileTool { - type Input = acp::WriteTextFileArguments; - type Output = (); - - const NAME: &'static str = "Write"; - - fn description(&self) -> &'static str { - "Write to a file replacing its contents" - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.path, input.content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), - }) - } -} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index aeb34a5e61..645674b5f1 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -13,7 +13,6 @@ pub fn init(cx: &mut App) { pub struct AllAgentServersSettings { pub gemini: Option, pub claude: Option, - pub codex: Option, } #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] @@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings { fn load(sources: SettingsSources, _: &mut App) -> Result { let mut settings = AllAgentServersSettings::default(); - for AllAgentServersSettings { - gemini, - claude, - codex, - } in sources.defaults_and_customizations() - { + for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { if gemini.is_some() { settings.gemini = gemini.clone(); } if claude.is_some() { settings.claude = claude.clone(); } - if codex.is_some() { - settings.codex = codex.clone(); - } } Ok(settings) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 6d7684bbfc..17575e42db 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -2211,13 +2211,16 @@ impl Render for AcpThreadView { .child(self.render_pending_auth_state()) .child(h_flex().mt_1p5().justify_center().children( connection.auth_methods().into_iter().map(|method| { - Button::new(SharedString::from(method.id.0.clone()), method.label) - .on_click({ - let method_id = method.id.clone(); - cx.listener(move |this, _, window, cx| { - this.authenticate(method_id.clone(), window, cx) - }) + Button::new( + SharedString::from(method.id.0.clone()), + method.label.clone(), + ) + .on_click({ + let method_id = method.id.clone(); + cx.listener(move |this, _, window, cx| { + this.authenticate(method_id.clone(), window, cx) }) + }) }), )), ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)), diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 91217cb030..875320372d 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1991,20 +1991,6 @@ impl AgentPanel { ); }), ) - .item( - ContextMenuEntry::new("New Codex Thread") - .icon(IconName::AiOpenAi) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Codex), - } - .boxed_clone(), - cx, - ); - }), - ) }); menu })) @@ -2666,25 +2652,6 @@ impl AgentPanel { ) }, ), - ) - .child( - NewThreadButton::new( - "new-codex-thread-btn", - "New Codex Thread", - IconName::AiOpenAi, - ) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Codex, - ), - }), - cx, - ) - }, - ), ), ) }), diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 4b75cc9e77..6ae78585de 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -150,7 +150,6 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, - Codex, } impl ExternalAgent { @@ -158,7 +157,6 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), - ExternalAgent::Codex => Rc::new(agent_servers::Codex), } } } From 8acb58b6e50fa26128675a9b8f60ee82a28c90dd Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 1 Aug 2025 18:15:26 -0300 Subject: [PATCH 12/25] Fix types --- crates/agent_servers/src/acp_connection.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index ca9ec2aea0..0ced22fc65 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -1,11 +1,10 @@ -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp, Agent as _}; use collections::HashMap; use futures::channel::oneshot; use project::Project; use std::cell::RefCell; use std::path::Path; use std::rc::Rc; -use util::ResultExt; use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; @@ -15,7 +14,7 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired}; pub struct AcpConnection { server_name: &'static str, - connection: Rc, + connection: Rc, sessions: Rc>>, auth_methods: Vec, _io_task: Task>, @@ -51,7 +50,7 @@ impl AcpConnection { sessions: sessions.clone(), cx: cx.clone(), }; - let (connection, io_task) = acp::AgentConnection::new(client, stdin, stdout, { + let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, { let foreground_executor = cx.foreground_executor().clone(); move |fut| { foreground_executor.spawn(fut).detach(); @@ -149,8 +148,14 @@ impl AgentConnection for AcpConnection { .spawn(async move { Ok(conn.prompt(params).await?) }) } - fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { - self.connection.cancel(session_id.clone()).log_err(); + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + let conn = self.connection.clone(); + let params = acp::CancelledNotification { + session_id: session_id.clone(), + }; + cx.foreground_executor() + .spawn(async move { conn.cancelled(params).await }) + .detach(); } } From 8890f590b1f1791e3640c188c62859c00843f59a Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 22:25:04 -0600 Subject: [PATCH 13/25] Fix some breakages against agent-client-protocol/main --- crates/agent_ui/src/acp/thread_view.rs | 38 ++++++++++++++++---------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 167f7e6136..26166b6960 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -2844,10 +2844,6 @@ mod tests { } impl AgentConnection for StubAgentConnection { - fn name(&self) -> &'static str { - "StubAgentConnection" - } - fn new_thread( self: Rc, project: Entity, @@ -2863,17 +2859,27 @@ mod tests { .into(), ); let thread = cx - .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)) + .new(|cx| { + AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx) + }) .unwrap(); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] { + todo!() + } + + fn authenticate( + &self, + _method: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { unimplemented!() } - fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let sessions = self.sessions.lock(); let thread = sessions.get(¶ms.session_id).unwrap(); let mut tasks = vec![]; @@ -2920,10 +2926,6 @@ mod tests { struct SaboteurAgentConnection; impl AgentConnection for SaboteurAgentConnection { - fn name(&self) -> &'static str { - "SaboteurAgentConnection" - } - fn new_thread( self: Rc, project: Entity, @@ -2931,15 +2933,23 @@ mod tests { cx: &mut gpui::AsyncApp, ) -> Task>> { Task::ready(Ok(cx - .new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx)) + .new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx)) .unwrap())) } - fn authenticate(&self, _cx: &mut App) -> Task> { + fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] { + todo!() + } + + fn authenticate( + &self, + _method: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { unimplemented!() } - fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task> { + fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task> { Task::ready(Err(anyhow::anyhow!("Error prompting"))) } From afb5c4147ad665f3e88309fea4e54a5875ac6b7a Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 22:32:36 -0600 Subject: [PATCH 14/25] WIP: Add agent2 crate from test-driven-agent branch --- Cargo.lock | 33 ++ Cargo.toml | 2 + crates/agent2/Cargo.toml | 49 +++ crates/agent2/LICENSE-GPL | 1 + crates/agent2/src/agent2.rs | 6 + crates/agent2/src/prompts.rs | 29 ++ crates/agent2/src/templates.rs | 57 +++ crates/agent2/src/templates/base.hbs | 56 +++ crates/agent2/src/templates/glob.hbs | 8 + crates/agent2/src/thread.rs | 420 +++++++++++++++++++ crates/agent2/src/thread/tests.rs | 254 +++++++++++ crates/agent2/src/thread/tests/test_tools.rs | 83 ++++ crates/agent2/src/tools.rs | 1 + crates/agent2/src/tools/glob.rs | 76 ++++ 14 files changed, 1075 insertions(+) create mode 100644 crates/agent2/Cargo.toml create mode 120000 crates/agent2/LICENSE-GPL create mode 100644 crates/agent2/src/agent2.rs create mode 100644 crates/agent2/src/prompts.rs create mode 100644 crates/agent2/src/templates.rs create mode 100644 crates/agent2/src/templates/base.hbs create mode 100644 crates/agent2/src/templates/glob.hbs create mode 100644 crates/agent2/src/thread.rs create mode 100644 crates/agent2/src/thread/tests.rs create mode 100644 crates/agent2/src/thread/tests/test_tools.rs create mode 100644 crates/agent2/src/tools.rs create mode 100644 crates/agent2/src/tools/glob.rs diff --git a/Cargo.lock b/Cargo.lock index 767d9a4c9a..15d472ee32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,6 +148,39 @@ dependencies = [ "serde_json", ] +[[package]] +name = "agent2" +version = "0.1.0" +dependencies = [ + "anyhow", + "assistant_tool", + "assistant_tools", + "chrono", + "client", + "collections", + "ctor", + "env_logger 0.11.8", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "handlebars 4.5.0", + "language_model", + "language_models", + "parking_lot", + "project", + "reqwest_client", + "rust-embed", + "schemars", + "serde", + "serde_json", + "settings", + "smol", + "thiserror 2.0.12", + "util", + "worktree", +] + [[package]] name = "agent_servers" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 902c7e8d19..d01ae9f683 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/acp_thread", "crates/activity_indicator", "crates/agent", + "crates/agent2", "crates/agent_servers", "crates/agent_settings", "crates/agent_ui", @@ -228,6 +229,7 @@ edition = "2024" acp_thread = { path = "crates/acp_thread" } agent = { path = "crates/agent" } +agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml new file mode 100644 index 0000000000..d4a234fe50 --- /dev/null +++ b/crates/agent2/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "agent2" +version = "0.1.0" +edition = "2021" +license = "GPL-3.0-or-later" +publish = false + +[lib] +path = "src/agent2.rs" + +[lints] +workspace = true + +[dependencies] +anyhow.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true +chrono.workspace = true +collections.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +language_model.workspace = true +language_models.workspace = true +parking_lot.workspace = true +project.workspace = true +rust-embed.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +thiserror.workspace = true +util.workspace = true +worktree.workspace = true + +[dev-dependencies] +ctor.workspace = true +client = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +gpui = { workspace = true, "features" = ["test-support"] } +gpui_tokio.workspace = true +language_model = { workspace = true, "features" = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +worktree = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/LICENSE-GPL b/crates/agent2/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/agent2/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs new file mode 100644 index 0000000000..577bf170dd --- /dev/null +++ b/crates/agent2/src/agent2.rs @@ -0,0 +1,6 @@ +mod prompts; +mod templates; +mod thread; +mod tools; + +pub use thread::*; diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs new file mode 100644 index 0000000000..f11eaf426b --- /dev/null +++ b/crates/agent2/src/prompts.rs @@ -0,0 +1,29 @@ +use crate::{ + templates::{BaseTemplate, Template, Templates, WorktreeData}, + thread::Prompt, +}; +use anyhow::Result; +use gpui::{App, Entity}; +use project::Project; + +struct BasePrompt { + project: Entity, +} + +impl Prompt for BasePrompt { + fn render(&self, templates: &Templates, cx: &App) -> Result { + BaseTemplate { + os: std::env::consts::OS.to_string(), + shell: util::get_system_shell(), + worktrees: self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| WorktreeData { + root_name: worktree.read(cx).root_name().to_string(), + }) + .collect(), + } + .render(templates) + } +} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs new file mode 100644 index 0000000000..04569369be --- /dev/null +++ b/crates/agent2/src/templates.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use anyhow::Result; +use handlebars::Handlebars; +use rust_embed::RustEmbed; +use serde::Serialize; + +#[derive(RustEmbed)] +#[folder = "src/templates"] +#[include = "*.hbs"] +struct Assets; + +pub struct Templates(Handlebars<'static>); + +impl Templates { + pub fn new() -> Arc { + let mut handlebars = Handlebars::new(); + handlebars.register_embed_templates::().unwrap(); + Arc::new(Self(handlebars)) + } +} + +pub trait Template: Sized { + const TEMPLATE_NAME: &'static str; + + fn render(&self, templates: &Templates) -> Result + where + Self: Serialize + Sized, + { + Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) + } +} + +#[derive(Serialize)] +pub struct BaseTemplate { + pub os: String, + pub shell: String, + pub worktrees: Vec, +} + +impl Template for BaseTemplate { + const TEMPLATE_NAME: &'static str = "base.hbs"; +} + +#[derive(Serialize)] +pub struct WorktreeData { + pub root_name: String, +} + +#[derive(Serialize)] +pub struct GlobTemplate { + pub project_roots: String, +} + +impl Template for GlobTemplate { + const TEMPLATE_NAME: &'static str = "glob.hbs"; +} diff --git a/crates/agent2/src/templates/base.hbs b/crates/agent2/src/templates/base.hbs new file mode 100644 index 0000000000..7eef231e32 --- /dev/null +++ b/crates/agent2/src/templates/base.hbs @@ -0,0 +1,56 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. Refer to the USER in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. + +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{root_name}}` +{{/each}} + +- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above. +- When looking for symbols in the project, prefer the `grep` tool. +- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. +- Bias towards not asking the user for help if you can find the answer yourself. + +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed) + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} diff --git a/crates/agent2/src/templates/glob.hbs b/crates/agent2/src/templates/glob.hbs new file mode 100644 index 0000000000..3bf992b093 --- /dev/null +++ b/crates/agent2/src/templates/glob.hbs @@ -0,0 +1,8 @@ +Find paths on disk with glob patterns. + +Assume that all glob patterns are matched in a project directory with the following entries. + +{{project_roots}} + +When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be +sure to anchor them with one of the directories listed above. diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs new file mode 100644 index 0000000000..c5b3572b5f --- /dev/null +++ b/crates/agent2/src/thread.rs @@ -0,0 +1,420 @@ +use crate::templates::Templates; +use anyhow::{anyhow, Result}; +use futures::{channel::mpsc, future}; +use gpui::{App, Context, SharedString, Task}; +use language_model::{ + CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason, +}; +use schemars::{JsonSchema, Schema}; +use serde::Deserialize; +use smol::stream::StreamExt; +use std::{collections::BTreeMap, sync::Arc}; +use util::ResultExt; + +#[derive(Debug)] +pub struct AgentMessage { + pub role: Role, + pub content: Vec, +} + +pub type AgentResponseEvent = LanguageModelCompletionEvent; + +pub trait Prompt { + fn render(&self, prompts: &Templates, cx: &App) -> Result; +} + +pub struct Thread { + messages: Vec, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option>, + system_prompts: Vec>, + tools: BTreeMap>, + templates: Arc, + // project: Entity, + // action_log: Entity, +} + +impl Thread { + pub fn new(templates: Arc) -> Self { + Self { + messages: Vec::new(), + completion_mode: CompletionMode::Normal, + system_prompts: Vec::new(), + running_turn: None, + tools: BTreeMap::default(), + templates, + } + } + + pub fn set_mode(&mut self, mode: CompletionMode) { + self.completion_mode = mode; + } + + pub fn messages(&self) -> &[AgentMessage] { + &self.messages + } + + pub fn add_tool(&mut self, tool: impl AgentTool) { + self.tools.insert(tool.name(), tool.erase()); + } + + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() + } + + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( + &mut self, + model: Arc, + content: impl Into, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + cx.notify(); + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let system_message = self.build_system_message(cx); + self.messages.extend(system_message); + + self.messages.push(AgentMessage { + role: Role::User, + content: vec![content.into()], + }); + self.running_turn = Some(cx.spawn(async move |thread, cx| { + let turn_result = async { + // Perform one request, then keep looping if the model makes tool calls. + let mut completion_intent = CompletionIntent::UserPrompt; + loop { + let request = thread.update(cx, |thread, cx| { + thread.build_completion_request(completion_intent, cx) + })?; + + // println!( + // "request: {}", + // serde_json::to_string_pretty(&request).unwrap() + // ); + + // Stream events, appending to messages and collecting up tool uses. + let mut events = model.stream_completion(request, cx).await?; + let mut tool_uses = Vec::new(); + while let Some(event) = events.next().await { + match event { + Ok(event) => { + thread + .update(cx, |thread, cx| { + tool_uses.extend(thread.handle_streamed_completion_event( + event, + events_tx.clone(), + cx, + )); + }) + .ok(); + } + Err(error) => { + events_tx.unbounded_send(Err(error)).ok(); + break; + } + } + } + + // If there are no tool uses, the turn is done. + if tool_uses.is_empty() { + break; + } + + // If there are tool uses, wait for their results to be + // computed, then send them together in a single message on + // the next loop iteration. + let tool_results = future::join_all(tool_uses).await; + thread + .update(cx, |thread, _cx| { + thread.messages.push(AgentMessage { + role: Role::User, + content: tool_results.into_iter().map(Into::into).collect(), + }); + }) + .ok(); + completion_intent = CompletionIntent::ToolResults; + } + + Ok(()) + } + .await; + + if let Err(error) = turn_result { + events_tx.unbounded_send(Err(error)).ok(); + } + })); + events_rx + } + + pub fn build_system_message(&mut self, cx: &App) -> Option { + let mut system_message = AgentMessage { + role: Role::System, + content: Vec::new(), + }; + + for prompt in &self.system_prompts { + if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() { + system_message + .content + .push(MessageContent::Text(rendered_prompt)); + } + } + + (!system_message.content.is_empty()).then_some(system_message) + } + + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop in + /// send will send back to the model when it resolves. + fn handle_streamed_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + events_tx: mpsc::UnboundedSender>, + cx: &mut Context, + ) -> Option> { + use LanguageModelCompletionEvent::*; + events_tx.unbounded_send(Ok(event.clone())).ok(); + + match event { + Text(new_text) => self.handle_text_event(new_text, cx), + Thinking { text, signature } => { + todo!() + } + ToolUse(tool_use) => { + return self.handle_tool_use_event(tool_use, cx); + } + StartMessage { role, .. } => { + self.messages.push(AgentMessage { + role, + content: Vec::new(), + }); + } + UsageUpdate(_) => {} + Stop(stop_reason) => self.handle_stop_event(stop_reason), + StatusUpdate(_completion_request_status) => {} + RedactedThinking { data } => todo!(), + ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => todo!(), + } + + None + } + + fn handle_stop_event(&mut self, stop_reason: StopReason) { + match stop_reason { + StopReason::EndTurn | StopReason::ToolUse => {} + StopReason::MaxTokens => todo!(), + StopReason::Refusal => todo!(), + } + } + + fn handle_text_event(&mut self, new_text: String, cx: &mut Context) { + let last_message = self.last_assistant_message(); + if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message.content.push(MessageContent::Text(new_text)); + } + + cx.notify(); + } + + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + cx: &mut Context, + ) -> Option> { + cx.notify(); + + let last_message = self.last_assistant_message(); + + // Ensure the last message ends in the current tool use + let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { + if let MessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + if push_new_tool_use { + last_message.content.push(tool_use.clone().into()); + } + + if !tool_use.is_input_complete { + return None; + } + + if let Some(tool) = self.tools.get(tool_use.name.as_ref()) { + let pending_tool_result = tool.clone().run(tool_use.input, cx); + + Some(cx.foreground_executor().spawn(async move { + match pending_tool_result.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), + output: None, + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) + } else { + Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(format!( + "No tool named {} exists", + tool_use.name + ))), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })) + } + } + + /// Guarantees the last message is from the assistant and returns a mutable reference. + fn last_assistant_message(&mut self) -> &mut AgentMessage { + if self + .messages + .last() + .map_or(true, |m| m.role != Role::Assistant) + { + self.messages.push(AgentMessage { + role: Role::Assistant, + content: Vec::new(), + }); + } + self.messages.last_mut().unwrap() + } + + fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &mut App, + ) -> LanguageModelRequest { + LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(completion_intent), + mode: Some(self.completion_mode), + messages: self.build_request_messages(), + tools: self + .tools + .values() + .filter_map(|tool| { + Some(LanguageModelRequestTool { + name: tool.name().to_string(), + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .log_err()?, + }) + }) + .collect(), + tool_choice: None, + stop: Vec::new(), + temperature: None, + } + } + + fn build_request_messages(&self) -> Vec { + self.messages + .iter() + .map(|message| LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + }) + .collect() + } +} + +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + JsonSchema; + + fn name(&self) -> SharedString; + fn description(&self, _cx: &mut App) -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema { + assistant_tools::root_schema_for::(format) + } + + /// Runs the tool with the provided input. + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; + + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} + +pub struct Erased(T); + +pub trait AgentToolErased { + fn name(&self) -> SharedString; + fn description(&self, cx: &mut App) -> SharedString; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; +} + +impl AgentToolErased for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + self.0.name() + } + + fn description(&self, cx: &mut App) -> SharedString { + self.0.description(cx) + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + Ok(serde_json::to_value(self.0.input_schema(format))?) + } + + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task> { + let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); + match parsed_input { + Ok(input) => self.0.clone().run(input, cx), + Err(error) => Task::ready(Err(anyhow!(error))), + } + } +} diff --git a/crates/agent2/src/thread/tests.rs b/crates/agent2/src/thread/tests.rs new file mode 100644 index 0000000000..9bbf95c6cf --- /dev/null +++ b/crates/agent2/src/thread/tests.rs @@ -0,0 +1,254 @@ +use super::*; +use client::{proto::language_server_prompt_request, Client, UserStore}; +use fs::FakeFs; +use gpui::{AppContext, Entity, TestAppContext}; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRegistry, MessageContent, StopReason, +}; +use reqwest_client::ReqwestClient; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::{sync::Arc, time::Duration}; + +mod test_tools; +use test_tools::*; + +#[gpui::test] +async fn test_echo(cx: &mut TestAppContext) { + let AgentTest { model, agent, .. } = setup(cx).await; + + let events = agent + .update(cx, |agent, cx| { + agent.send(model.clone(), "Testing: Reply with 'Hello'", cx) + }) + .collect() + .await; + agent.update(cx, |agent, _cx| { + assert_eq!( + agent.messages.last().unwrap().content, + vec![MessageContent::Text("Hello".to_string())] + ); + }); + assert_eq!(stop_events(events), vec![StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_basic_tool_calls(cx: &mut TestAppContext) { + let AgentTest { model, agent, .. } = setup(cx).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let events = agent + .update(cx, |agent, cx| { + agent.add_tool(EchoTool); + agent.send( + model.clone(), + "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + cx, + ) + }) + .collect() + .await; + assert_eq!( + stop_events(events), + vec![StopReason::ToolUse, StopReason::EndTurn] + ); + + // Test a tool calls that's likely to complete *after* streaming stops. + let events = agent + .update(cx, |agent, cx| { + agent.remove_tool(&AgentTool::name(&EchoTool)); + agent.add_tool(DelayTool); + agent.send( + model.clone(), + "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + cx, + ) + }) + .collect() + .await; + assert_eq!( + stop_events(events), + vec![StopReason::ToolUse, StopReason::EndTurn] + ); + agent.update(cx, |agent, _cx| { + assert!(agent + .messages + .last() + .unwrap() + .content + .iter() + .any(|content| { + if let MessageContent::Text(text) = content { + text.contains("Ding") + } else { + false + } + })); + }); +} + +#[gpui::test] +async fn test_streaming_tool_calls(cx: &mut TestAppContext) { + let AgentTest { model, agent, .. } = setup(cx).await; + + // Test a tool call that's likely to complete *before* streaming stops. + let mut events = agent.update(cx, |agent, cx| { + agent.add_tool(WordListTool); + agent.send(model.clone(), "Test the word_list tool.", cx) + }); + + let mut saw_partial_tool_use = false; + while let Some(event) = events.next().await { + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event { + agent.update(cx, |agent, _cx| { + // Look for a tool use in the agent's last message + let last_content = agent.messages().last().unwrap().content.last().unwrap(); + if let MessageContent::ToolUse(last_tool_use) = last_content { + assert_eq!(last_tool_use.name.as_ref(), "word_list"); + if tool_use_event.is_input_complete { + last_tool_use + .input + .get("a") + .expect("'a' has streamed because input is now complete"); + last_tool_use + .input + .get("g") + .expect("'g' has streamed because input is now complete"); + } else { + if !last_tool_use.is_input_complete + && last_tool_use.input.get("g").is_none() + { + saw_partial_tool_use = true; + } + } + } else { + panic!("last content should be a tool use"); + } + }); + } + } + + assert!( + saw_partial_tool_use, + "should see at least one partially streamed tool use in the history" + ); +} + +#[gpui::test] +async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { + let AgentTest { model, agent, .. } = setup(cx).await; + + // Test concurrent tool calls with different delay times + let events = agent + .update(cx, |agent, cx| { + agent.add_tool(DelayTool); + agent.send( + model.clone(), + "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + cx, + ) + }) + .collect() + .await; + + let stop_reasons = stop_events(events); + if stop_reasons.len() == 2 { + assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]); + } else if stop_reasons.len() == 3 { + assert_eq!( + stop_reasons, + vec![ + StopReason::ToolUse, + StopReason::ToolUse, + StopReason::EndTurn + ] + ); + } else { + panic!("Expected either 1 or 2 tool uses followed by end turn"); + } + + agent.update(cx, |agent, _cx| { + let last_message = agent.messages.last().unwrap(); + let text = last_message + .content + .iter() + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.as_str()) + } else { + None + } + }) + .collect::(); + + assert!(text.contains("Ding")); + }); +} + +/// Filters out the stop events for asserting against in tests +fn stop_events( + result_events: Vec>, +) -> Vec { + result_events + .into_iter() + .filter_map(|event| match event.unwrap() { + LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason), + _ => None, + }) + .collect() +} + +struct AgentTest { + model: Arc, + agent: Entity, +} + +async fn setup(cx: &mut TestAppContext) -> AgentTest { + cx.executor().allow_parking(); + cx.update(settings::init); + let fs = FakeFs::new(cx.executor().clone()); + // let project = Project::test(fs.clone(), [], cx).await; + // let action_log = cx.new(|_| ActionLog::new(project.clone())); + let templates = Templates::new(); + let agent = cx.new(|_| Thread::new(templates)); + + let model = cx + .update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + let models = LanguageModelRegistry::read_global(cx); + let model = models + .available_models(cx) + .find(|model| model.id().0 == "claude-3-7-sonnet-latest") + .unwrap(); + + let provider = models.provider(&model.provider_id()).unwrap(); + let authenticated = provider.authenticate(cx); + + cx.spawn(async move |cx| { + authenticated.await.unwrap(); + model + }) + }) + .await; + + AgentTest { model, agent } +} + +#[cfg(test)] +#[ctor::ctor] +fn init_logger() { + if std::env::var("RUST_LOG").is_ok() { + env_logger::init(); + } +} diff --git a/crates/agent2/src/thread/tests/test_tools.rs b/crates/agent2/src/thread/tests/test_tools.rs new file mode 100644 index 0000000000..38da1bb2f2 --- /dev/null +++ b/crates/agent2/src/thread/tests/test_tools.rs @@ -0,0 +1,83 @@ +use super::*; + +/// A tool that echoes its input +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct EchoToolInput { + /// The text to echo. + text: String, +} + +pub struct EchoTool; + +impl AgentTool for EchoTool { + type Input = EchoToolInput; + + fn name(&self) -> SharedString { + "echo".into() + } + + fn run(self: Arc, input: Self::Input, _cx: &mut App) -> Task> { + Task::ready(Ok(input.text)) + } +} + +/// A tool that waits for a specified delay +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct DelayToolInput { + /// The delay in milliseconds. + ms: u64, +} + +pub struct DelayTool; + +impl AgentTool for DelayTool { + type Input = DelayToolInput; + + fn name(&self) -> SharedString { + "delay".into() + } + + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> + where + Self: Sized, + { + cx.foreground_executor().spawn(async move { + smol::Timer::after(Duration::from_millis(input.ms)).await; + Ok("Ding".to_string()) + }) + } +} + +/// A tool that takes an object with map from letters to random words starting with that letter. +/// All fiealds are required! Pass a word for every letter! +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct WordListInput { + /// Provide a random word that starts with A. + a: Option, + /// Provide a random word that starts with B. + b: Option, + /// Provide a random word that starts with C. + c: Option, + /// Provide a random word that starts with D. + d: Option, + /// Provide a random word that starts with E. + e: Option, + /// Provide a random word that starts with F. + f: Option, + /// Provide a random word that starts with G. + g: Option, +} + +pub struct WordListTool; + +impl AgentTool for WordListTool { + type Input = WordListInput; + + fn name(&self) -> SharedString { + "word_list".into() + } + + fn run(self: Arc, _input: Self::Input, _cx: &mut App) -> Task> { + Task::ready(Ok("ok".to_string())) + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs new file mode 100644 index 0000000000..cf3162abfa --- /dev/null +++ b/crates/agent2/src/tools.rs @@ -0,0 +1 @@ +mod glob; diff --git a/crates/agent2/src/tools/glob.rs b/crates/agent2/src/tools/glob.rs new file mode 100644 index 0000000000..9434311aaf --- /dev/null +++ b/crates/agent2/src/tools/glob.rs @@ -0,0 +1,76 @@ +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; +use worktree::Snapshot as WorktreeSnapshot; + +use crate::{ + templates::{GlobTemplate, Template, Templates}, + thread::AgentTool, +}; + +// Description is dynamic, see `fn description` below +#[derive(Deserialize, JsonSchema)] +struct GlobInput { + /// A POSIX glob pattern + glob: SharedString, +} + +struct GlobTool { + project: Entity, + templates: Arc, +} + +impl AgentTool for GlobTool { + type Input = GlobInput; + + fn name(&self) -> SharedString { + "glob".into() + } + + fn description(&self, cx: &mut App) -> SharedString { + let project_roots = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).root_name().into()) + .collect::>() + .join("\n"); + + GlobTemplate { project_roots } + .render(&self.templates) + .expect("template failed to render") + .into() + } + + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> { + let path_matcher = match PathMatcher::new([&input.glob]) { + Ok(matcher) => matcher, + Err(error) => return Task::ready(Err(anyhow!(error))), + }; + + let snapshots: Vec = self + .project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + let paths = snapshots.iter().flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }); + let output = paths + .map(|path| format!("{}\n", path.display())) + .collect::(); + Ok(output) + }) + } +} From 84d6a0fae99e6f9c0cf4dcc4eaa8c6ade50ce945 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 22:39:08 -0600 Subject: [PATCH 15/25] Fix agent2 compilation errors and warnings - Add cloud_llm_client dependency for CompletionIntent and CompletionMode - Fix LanguageModelRequest initialization with missing thinking_allowed field - Update StartMessage handling to use Assistant role - Fix MessageContent conversions to use enum variants directly - Fix input_schema implementation to use schemars directly - Suppress unused variable and dead code warnings --- Cargo.lock | 1 + crates/agent2/Cargo.toml | 1 + crates/agent2/src/prompts.rs | 5 +++-- crates/agent2/src/thread.rs | 42 ++++++++++++++++++++++-------------- 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 15d472ee32..d2d4798880 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -157,6 +157,7 @@ dependencies = [ "assistant_tools", "chrono", "client", + "cloud_llm_client", "collections", "ctor", "env_logger 0.11.8", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index d4a234fe50..c1ce775ae6 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -16,6 +16,7 @@ anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true chrono.workspace = true +cloud_llm_client.workspace = true collections.workspace = true fs.workspace = true futures.workspace = true diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs index f11eaf426b..ab30b117a8 100644 --- a/crates/agent2/src/prompts.rs +++ b/crates/agent2/src/prompts.rs @@ -6,11 +6,12 @@ use anyhow::Result; use gpui::{App, Entity}; use project::Project; -struct BasePrompt { +#[allow(dead_code)] +struct _BasePrompt { project: Entity, } -impl Prompt for BasePrompt { +impl Prompt for _BasePrompt { fn render(&self, templates: &Templates, cx: &App) -> Result { BaseTemplate { os: std::env::consts::OS.to_string(), diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index c5b3572b5f..85e2030f8b 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,12 +1,13 @@ use crate::templates::Templates; use anyhow::{anyhow, Result}; +use cloud_llm_client::{CompletionIntent, CompletionMode}; use futures::{channel::mpsc, future}; use gpui::{App, Context, SharedString, Task}; use language_model::{ - CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, - LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, MessageContent, Role, StopReason, }; use schemars::{JsonSchema, Schema}; use serde::Deserialize; @@ -138,7 +139,10 @@ impl Thread { .update(cx, |thread, _cx| { thread.messages.push(AgentMessage { role: Role::User, - content: tool_results.into_iter().map(Into::into).collect(), + content: tool_results + .into_iter() + .map(MessageContent::ToolResult) + .collect(), }); }) .ok(); @@ -187,27 +191,30 @@ impl Thread { match event { Text(new_text) => self.handle_text_event(new_text, cx), - Thinking { text, signature } => { + Thinking { + text: _text, + signature: _signature, + } => { todo!() } ToolUse(tool_use) => { return self.handle_tool_use_event(tool_use, cx); } - StartMessage { role, .. } => { + StartMessage { .. } => { self.messages.push(AgentMessage { - role, + role: Role::Assistant, content: Vec::new(), }); } UsageUpdate(_) => {} Stop(stop_reason) => self.handle_stop_event(stop_reason), StatusUpdate(_completion_request_status) => {} - RedactedThinking { data } => todo!(), + RedactedThinking { data: _data } => todo!(), ToolUseJsonParseError { - id, - tool_name, - raw_input, - json_parse_error, + id: _id, + tool_name: _tool_name, + raw_input: _raw_input, + json_parse_error: _json_parse_error, } => todo!(), } @@ -256,7 +263,9 @@ impl Thread { } }); if push_new_tool_use { - last_message.content.push(tool_use.clone().into()); + last_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); } if !tool_use.is_input_complete { @@ -340,6 +349,7 @@ impl Thread { tool_choice: None, stop: Vec::new(), temperature: None, + thinking_allowed: false, } } @@ -373,8 +383,8 @@ where } /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema { - assistant_tools::root_schema_for::(format) + fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema { + schemars::schema_for!(Self::Input) } /// Runs the tool with the provided input. From afc8cf6098ed542e09ea488602830619b045367f Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 22:48:35 -0600 Subject: [PATCH 16/25] Refactor agent2 tests structure - Move tests from thread/tests.rs to tests/mod.rs - Move test_tools from thread/tests/test_tools.rs to tests/test_tools.rs - Update imports and fix compilation errors in tests - Fix private field access by using public messages() method - Add necessary imports for test modules --- crates/agent2/src/agent2.rs | 3 +++ crates/agent2/src/{thread/tests.rs => tests/mod.rs} | 9 +++++---- crates/agent2/src/{thread => }/tests/test_tools.rs | 2 ++ 3 files changed, 10 insertions(+), 4 deletions(-) rename crates/agent2/src/{thread/tests.rs => tests/mod.rs} (97%) rename crates/agent2/src/{thread => }/tests/test_tools.rs (97%) diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 577bf170dd..71d5ea711c 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -3,4 +3,7 @@ mod templates; mod thread; mod tools; +#[cfg(test)] +mod tests; + pub use thread::*; diff --git a/crates/agent2/src/thread/tests.rs b/crates/agent2/src/tests/mod.rs similarity index 97% rename from crates/agent2/src/thread/tests.rs rename to crates/agent2/src/tests/mod.rs index 9bbf95c6cf..42078a3ffa 100644 --- a/crates/agent2/src/thread/tests.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,5 +1,6 @@ use super::*; -use client::{proto::language_server_prompt_request, Client, UserStore}; +use crate::templates::Templates; +use client::{Client, UserStore}; use fs::FakeFs; use gpui::{AppContext, Entity, TestAppContext}; use language_model::{ @@ -27,7 +28,7 @@ async fn test_echo(cx: &mut TestAppContext) { .await; agent.update(cx, |agent, _cx| { assert_eq!( - agent.messages.last().unwrap().content, + agent.messages().last().unwrap().content, vec![MessageContent::Text("Hello".to_string())] ); }); @@ -74,7 +75,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { ); agent.update(cx, |agent, _cx| { assert!(agent - .messages + .messages() .last() .unwrap() .content @@ -170,7 +171,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { } agent.update(cx, |agent, _cx| { - let last_message = agent.messages.last().unwrap(); + let last_message = agent.messages().last().unwrap(); let text = last_message .content .iter() diff --git a/crates/agent2/src/thread/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs similarity index 97% rename from crates/agent2/src/thread/tests/test_tools.rs rename to crates/agent2/src/tests/test_tools.rs index 38da1bb2f2..43d0414499 100644 --- a/crates/agent2/src/thread/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -1,4 +1,6 @@ use super::*; +use anyhow::Result; +use gpui::{App, SharedString, Task}; /// A tool that echoes its input #[derive(JsonSchema, Serialize, Deserialize)] From 387ee1be8d9eab9f49abf29a0b29f3a750c94cd5 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 22:53:37 -0600 Subject: [PATCH 17/25] Clean up warnings in agent2 - Remove underscore prefix from BasePrompt struct name - Remove unused imports and variables in tests - Fix unused parameter warning in async closure - Rename AgentToolErased to AnyAgentTool for clarity --- crates/agent2/src/prompts.rs | 4 ++-- crates/agent2/src/tests/mod.rs | 6 +----- crates/agent2/src/thread.rs | 8 ++++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs index ab30b117a8..015f56f4db 100644 --- a/crates/agent2/src/prompts.rs +++ b/crates/agent2/src/prompts.rs @@ -7,11 +7,11 @@ use gpui::{App, Entity}; use project::Project; #[allow(dead_code)] -struct _BasePrompt { +struct BasePrompt { project: Entity, } -impl Prompt for _BasePrompt { +impl Prompt for BasePrompt { fn render(&self, templates: &Templates, cx: &App) -> Result { BaseTemplate { os: std::env::consts::OS.to_string(), diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 42078a3ffa..ac790c8498 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,7 +1,6 @@ use super::*; use crate::templates::Templates; use client::{Client, UserStore}; -use fs::FakeFs; use gpui::{AppContext, Entity, TestAppContext}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -209,9 +208,6 @@ struct AgentTest { async fn setup(cx: &mut TestAppContext) -> AgentTest { cx.executor().allow_parking(); cx.update(settings::init); - let fs = FakeFs::new(cx.executor().clone()); - // let project = Project::test(fs.clone(), [], cx).await; - // let action_log = cx.new(|_| ActionLog::new(project.clone())); let templates = Templates::new(); let agent = cx.new(|_| Thread::new(templates)); @@ -236,7 +232,7 @@ async fn setup(cx: &mut TestAppContext) -> AgentTest { let provider = models.provider(&model.provider_id()).unwrap(); let authenticated = provider.authenticate(cx); - cx.spawn(async move |cx| { + cx.spawn(async move |_cx| { authenticated.await.unwrap(); model }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 85e2030f8b..bc88cf1d95 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -35,7 +35,7 @@ pub struct Thread { /// we run tools, report their results. running_turn: Option>, system_prompts: Vec>, - tools: BTreeMap>, + tools: BTreeMap>, templates: Arc, // project: Entity, // action_log: Entity, @@ -390,21 +390,21 @@ where /// Runs the tool with the provided input. fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; - fn erase(self) -> Arc { + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } } pub struct Erased(T); -pub trait AgentToolErased { +pub trait AnyAgentTool { fn name(&self) -> SharedString; fn description(&self, cx: &mut App) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; } -impl AgentToolErased for Erased> +impl AnyAgentTool for Erased> where T: AgentTool, { From 9e1c7fdfea5a069d69411213d72cdc2c6a3ce475 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 23:14:21 -0600 Subject: [PATCH 18/25] Rename agent to thread in tests to avoid confusion Preparing for the introduction of a new Agent type that will implement the agent-client-protocol Agent trait. The existing Thread type represents individual conversation sessions, while Agent will manage multiple sessions. --- crates/agent2/src/tests/mod.rs | 74 +++++++++++++++++----------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index ac790c8498..f3d9a35c2b 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -17,17 +17,17 @@ use test_tools::*; #[gpui::test] async fn test_echo(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; - let events = agent - .update(cx, |agent, cx| { - agent.send(model.clone(), "Testing: Reply with 'Hello'", cx) + let events = thread + .update(cx, |thread, cx| { + thread.send(model.clone(), "Testing: Reply with 'Hello'", cx) }) .collect() .await; - agent.update(cx, |agent, _cx| { + thread.update(cx, |thread, _cx| { assert_eq!( - agent.messages().last().unwrap().content, + thread.messages().last().unwrap().content, vec![MessageContent::Text("Hello".to_string())] ); }); @@ -36,13 +36,13 @@ async fn test_echo(cx: &mut TestAppContext) { #[gpui::test] async fn test_basic_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test a tool call that's likely to complete *before* streaming stops. - let events = agent - .update(cx, |agent, cx| { - agent.add_tool(EchoTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(EchoTool); + thread.send( model.clone(), "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", cx, @@ -56,11 +56,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { ); // Test a tool calls that's likely to complete *after* streaming stops. - let events = agent - .update(cx, |agent, cx| { - agent.remove_tool(&AgentTool::name(&EchoTool)); - agent.add_tool(DelayTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.remove_tool(&AgentTool::name(&EchoTool)); + thread.add_tool(DelayTool); + thread.send( model.clone(), "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", cx, @@ -72,8 +72,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { stop_events(events), vec![StopReason::ToolUse, StopReason::EndTurn] ); - agent.update(cx, |agent, _cx| { - assert!(agent + thread.update(cx, |thread, _cx| { + assert!(thread .messages() .last() .unwrap() @@ -91,20 +91,20 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { #[gpui::test] async fn test_streaming_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test a tool call that's likely to complete *before* streaming stops. - let mut events = agent.update(cx, |agent, cx| { - agent.add_tool(WordListTool); - agent.send(model.clone(), "Test the word_list tool.", cx) + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(WordListTool); + thread.send(model.clone(), "Test the word_list tool.", cx) }); let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event { - agent.update(cx, |agent, _cx| { - // Look for a tool use in the agent's last message - let last_content = agent.messages().last().unwrap().content.last().unwrap(); + thread.update(cx, |thread, _cx| { + // Look for a tool use in the thread's last message + let last_content = thread.messages().last().unwrap().content.last().unwrap(); if let MessageContent::ToolUse(last_tool_use) = last_content { assert_eq!(last_tool_use.name.as_ref(), "word_list"); if tool_use_event.is_input_complete { @@ -138,13 +138,13 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { #[gpui::test] async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { - let AgentTest { model, agent, .. } = setup(cx).await; + let ThreadTest { model, thread, .. } = setup(cx).await; // Test concurrent tool calls with different delay times - let events = agent - .update(cx, |agent, cx| { - agent.add_tool(DelayTool); - agent.send( + let events = thread + .update(cx, |thread, cx| { + thread.add_tool(DelayTool); + thread.send( model.clone(), "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", cx, @@ -169,8 +169,8 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { panic!("Expected either 1 or 2 tool uses followed by end turn"); } - agent.update(cx, |agent, _cx| { - let last_message = agent.messages().last().unwrap(); + thread.update(cx, |thread, _cx| { + let last_message = thread.messages().last().unwrap(); let text = last_message .content .iter() @@ -200,16 +200,16 @@ fn stop_events( .collect() } -struct AgentTest { +struct ThreadTest { model: Arc, - agent: Entity, + thread: Entity, } -async fn setup(cx: &mut TestAppContext) -> AgentTest { +async fn setup(cx: &mut TestAppContext) -> ThreadTest { cx.executor().allow_parking(); cx.update(settings::init); let templates = Templates::new(); - let agent = cx.new(|_| Thread::new(templates)); + let thread = cx.new(|_| Thread::new(templates)); let model = cx .update(|cx| { @@ -239,7 +239,7 @@ async fn setup(cx: &mut TestAppContext) -> AgentTest { }) .await; - AgentTest { model, agent } + ThreadTest { model, thread } } #[cfg(test)] From 27877325bccbcca770f92071fee19cf5f169145e Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 1 Aug 2025 23:22:21 -0600 Subject: [PATCH 19/25] Implement agent-client-protocol Agent trait for agent2 Added Agent struct that implements the acp::Agent trait with: - Complete: initialize (protocol handshake) and authenticate (stub auth) - Partial: new_session (creates ID but needs GPUI context for Thread) - Partial: cancelled (removes session but needs GPUI cleanup) - Stub: load_session and prompt (need GPUI context integration) The implementation uses RefCell for session management since trait methods take &self, and Cell for simple authentication state. Templates are Arc'd for potential Send requirements. Next steps: - Integrate GPUI context for Thread creation/management - Implement content type conversions between acp and agent2 - Add proper session persistence for load_session - Stream responses back through the protocol --- Cargo.lock | 3 + crates/agent2/Cargo.toml | 3 + crates/agent2/src/agent.rs | 173 ++++++++++++++++++++++++++++++++++++ crates/agent2/src/agent2.rs | 2 + 4 files changed, 181 insertions(+) create mode 100644 crates/agent2/src/agent.rs diff --git a/Cargo.lock b/Cargo.lock index d2d4798880..8aaaff3947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,7 @@ dependencies = [ name = "agent2" version = "0.1.0" dependencies = [ + "agent-client-protocol", "anyhow", "assistant_tool", "assistant_tools", @@ -168,6 +169,7 @@ dependencies = [ "handlebars 4.5.0", "language_model", "language_models", + "log", "parking_lot", "project", "reqwest_client", @@ -179,6 +181,7 @@ dependencies = [ "smol", "thiserror 2.0.12", "util", + "uuid", "worktree", ] diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index c1ce775ae6..30e0f36c0d 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -12,6 +12,7 @@ path = "src/agent2.rs" workspace = true [dependencies] +agent-client-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -24,6 +25,7 @@ gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } language_model.workspace = true language_models.workspace = true +log.workspace = true parking_lot.workspace = true project.workspace = true rust-embed.workspace = true @@ -34,6 +36,7 @@ settings.workspace = true smol.workspace = true thiserror.workspace = true util.workspace = true +uuid.workspace = true worktree.workspace = true [dev-dependencies] diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs new file mode 100644 index 0000000000..c608069586 --- /dev/null +++ b/crates/agent2/src/agent.rs @@ -0,0 +1,173 @@ +//! Agent implementation for the agent-client-protocol +//! +//! Implementation Status: +//! - [x] initialize: Complete - Basic protocol handshake +//! - [x] authenticate: Complete - Accepts any auth (stub) +//! - [~] new_session: Partial - Creates session ID but Thread creation needs GPUI context +//! - [~] load_session: Stub - Returns not implemented +//! - [ ] prompt: Stub - Needs GPUI context and type conversions +//! - [~] cancelled: Partial - Removes session from map but needs GPUI cleanup + +use agent_client_protocol as acp; +use gpui::Entity; +use std::cell::{Cell, RefCell}; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::{templates::Templates, Thread}; + +pub struct Agent { + /// Session ID -> Thread entity mapping + sessions: RefCell>>, + /// Shared templates for all threads + templates: Arc, + /// Current protocol version we support + protocol_version: acp::ProtocolVersion, + /// Authentication state + authenticated: Cell, +} + +impl Agent { + pub fn new(templates: Arc) -> Self { + Self { + sessions: RefCell::new(HashMap::new()), + templates, + protocol_version: acp::VERSION, + authenticated: Cell::new(false), + } + } +} + +impl acp::Agent for Agent { + /// COMPLETE: Initialize handshake with client + async fn initialize( + &self, + arguments: acp::InitializeRequest, + ) -> Result { + // For now, we just use the client's requested version + let response_version = arguments.protocol_version.clone(); + + Ok(acp::InitializeResponse { + protocol_version: response_version, + agent_capabilities: acp::AgentCapabilities::default(), + auth_methods: vec![ + // STUB: No authentication required for now + acp::AuthMethod { + id: acp::AuthMethodId("none".into()), + label: "No Authentication".to_string(), + description: Some("No authentication required".to_string()), + }, + ], + }) + } + + /// COMPLETE: Handle authentication (currently just accepts any auth) + async fn authenticate(&self, _arguments: acp::AuthenticateRequest) -> Result<(), acp::Error> { + // STUB: Accept any authentication method for now + self.authenticated.set(true); + Ok(()) + } + + /// PARTIAL: Create a new session + async fn new_session( + &self, + arguments: acp::NewSessionRequest, + ) -> Result { + // Check if authenticated + if !self.authenticated.get() { + return Ok(acp::NewSessionResponse { session_id: None }); + } + + // STUB: Generate a simple session ID + let session_id = acp::SessionId(format!("session-{}", uuid::Uuid::new_v4()).into()); + + // Create a new Thread for this session + // TODO: This needs to be done on the main thread with proper GPUI context + // For now, we'll return the session ID and expect the actual Thread creation + // to happen when we have access to a GPUI context + + // STUB: MCP server support not implemented + if !arguments.mcp_servers.is_empty() { + log::warn!("MCP servers requested but not yet supported"); + } + + Ok(acp::NewSessionResponse { + session_id: Some(session_id), + }) + } + + /// STUB: Load existing session + async fn load_session( + &self, + _arguments: acp::LoadSessionRequest, + ) -> Result { + // STUB: Session persistence not implemented + Ok(acp::LoadSessionResponse { + auth_required: !self.authenticated.get(), + auth_methods: if self.authenticated.get() { + vec![] + } else { + vec![acp::AuthMethod { + id: acp::AuthMethodId("none".into()), + label: "No Authentication".to_string(), + description: Some("No authentication required".to_string()), + }] + }, + }) + } + + /// STUB: Handle prompts + async fn prompt(&self, arguments: acp::PromptRequest) -> Result<(), acp::Error> { + // TODO: This needs to be implemented with proper GPUI context access + // The implementation would: + // 1. Look up the Thread for this session + // 2. Convert acp::ContentBlock to agent2 message format + // 3. Call thread.send() with the converted message + // 4. Stream responses back to the client + + let _session_id = arguments.session_id; + let _prompt = arguments.prompt; + + // STUB: Just acknowledge receipt for now + log::info!("Received prompt for session: {}", _session_id.0); + + Err(acp::Error::internal_error().with_data("Prompt handling not yet implemented")) + } + + /// PARTIAL: Handle cancellation + async fn cancelled(&self, args: acp::CancelledNotification) -> Result<(), acp::Error> { + // Remove the session from our map + let removed = self.sessions.borrow_mut().remove(&args.session_id); + + if removed.is_some() { + // TODO: Properly clean up the Thread entity when we have GPUI context + log::info!("Session {} cancelled and removed", args.session_id.0); + Ok(()) + } else { + Err(acp::Error::invalid_request() + .with_data(format!("Session {} not found", args.session_id.0))) + } + } +} + +// Helper functions for type conversions between acp and agent2 types + +/// Convert acp::ContentBlock to agent2 message format +/// STUB: Needs implementation +fn convert_content_block(_block: acp::ContentBlock) -> String { + // TODO: Implement proper conversion + // This would handle: + // - Text content + // - Resource links + // - Images + // - Audio + // - Other content types + "".to_string() +} + +/// Convert agent2 messages to acp format for responses +/// STUB: Needs implementation +fn convert_to_acp_content(_content: &str) -> Vec { + // TODO: Implement proper conversion + vec![] +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 71d5ea711c..66ed32eccd 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,3 +1,4 @@ +mod agent; mod prompts; mod templates; mod thread; @@ -6,4 +7,5 @@ mod tools; #[cfg(test)] mod tests; +pub use agent::*; pub use thread::*; From 5d621bef789a7e534d1abb8ac9e4653955040d24 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 00:15:22 -0600 Subject: [PATCH 20/25] WIP --- Cargo.lock | 1 + crates/agent2/Cargo.toml | 1 + crates/agent2/src/agent.rs | 218 +++++++++++++++---------------------- 3 files changed, 87 insertions(+), 133 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8aaaff3947..705ed44d38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,7 @@ dependencies = [ name = "agent2" version = "0.1.0" dependencies = [ + "acp_thread", "agent-client-protocol", "anyhow", "assistant_tool", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 30e0f36c0d..72f5f14008 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -12,6 +12,7 @@ path = "src/agent2.rs" workspace = true [dependencies] +acp_thread.workspace = true agent-client-protocol.workspace = true anyhow.workspace = true assistant_tool.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index c608069586..c1c28ad41b 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,173 +1,125 @@ -//! Agent implementation for the agent-client-protocol -//! -//! Implementation Status: -//! - [x] initialize: Complete - Basic protocol handshake -//! - [x] authenticate: Complete - Accepts any auth (stub) -//! - [~] new_session: Partial - Creates session ID but Thread creation needs GPUI context -//! - [~] load_session: Stub - Returns not implemented -//! - [ ] prompt: Stub - Needs GPUI context and type conversions -//! - [~] cancelled: Partial - Removes session from map but needs GPUI cleanup - use agent_client_protocol as acp; -use gpui::Entity; -use std::cell::{Cell, RefCell}; +use anyhow::Result; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use project::Project; use std::collections::HashMap; +use std::path::Path; +use std::rc::Rc; use std::sync::Arc; use crate::{templates::Templates, Thread}; pub struct Agent { /// Session ID -> Thread entity mapping - sessions: RefCell>>, + sessions: HashMap>, /// Shared templates for all threads templates: Arc, - /// Current protocol version we support - protocol_version: acp::ProtocolVersion, - /// Authentication state - authenticated: Cell, } impl Agent { pub fn new(templates: Arc) -> Self { Self { - sessions: RefCell::new(HashMap::new()), + sessions: HashMap::new(), templates, - protocol_version: acp::VERSION, - authenticated: Cell::new(false), } } } -impl acp::Agent for Agent { - /// COMPLETE: Initialize handshake with client - async fn initialize( - &self, - arguments: acp::InitializeRequest, - ) -> Result { - // For now, we just use the client's requested version - let response_version = arguments.protocol_version.clone(); +/// Wrapper struct that implements the AgentConnection trait +pub struct AgentConnection(pub Entity); - Ok(acp::InitializeResponse { - protocol_version: response_version, - agent_capabilities: acp::AgentCapabilities::default(), - auth_methods: vec![ - // STUB: No authentication required for now - acp::AuthMethod { - id: acp::AuthMethodId("none".into()), - label: "No Authentication".to_string(), - description: Some("No authentication required".to_string()), - }, - ], +impl acp_thread::AgentConnection for AgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let _cwd = cwd.to_owned(); + let agent = self.0.clone(); + + cx.spawn(async move |cx| { + // Create Thread and store in Agent + let (session_id, _thread) = + agent.update(cx, |agent, cx: &mut gpui::Context| { + let thread = cx.new(|_| Thread::new(agent.templates.clone())); + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + agent.sessions.insert(session_id.clone(), thread.clone()); + (session_id, thread) + })?; + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx)) + })?; + + Ok(acp_thread) }) } - /// COMPLETE: Handle authentication (currently just accepts any auth) - async fn authenticate(&self, _arguments: acp::AuthenticateRequest) -> Result<(), acp::Error> { - // STUB: Accept any authentication method for now - self.authenticated.set(true); - Ok(()) + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process } - /// PARTIAL: Create a new session - async fn new_session( - &self, - arguments: acp::NewSessionRequest, - ) -> Result { - // Check if authenticated - if !self.authenticated.get() { - return Ok(acp::NewSessionResponse { session_id: None }); - } - - // STUB: Generate a simple session ID - let session_id = acp::SessionId(format!("session-{}", uuid::Uuid::new_v4()).into()); - - // Create a new Thread for this session - // TODO: This needs to be done on the main thread with proper GPUI context - // For now, we'll return the session ID and expect the actual Thread creation - // to happen when we have access to a GPUI context - - // STUB: MCP server support not implemented - if !arguments.mcp_servers.is_empty() { - log::warn!("MCP servers requested but not yet supported"); - } - - Ok(acp::NewSessionResponse { - session_id: Some(session_id), - }) + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) } - /// STUB: Load existing session - async fn load_session( - &self, - _arguments: acp::LoadSessionRequest, - ) -> Result { - // STUB: Session persistence not implemented - Ok(acp::LoadSessionResponse { - auth_required: !self.authenticated.get(), - auth_methods: if self.authenticated.get() { - vec![] - } else { - vec![acp::AuthMethod { - id: acp::AuthMethodId("none".into()), - label: "No Authentication".to_string(), - description: Some("No authentication required".to_string()), - }] - }, - }) - } + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { + let session_id = params.session_id.clone(); + let agent = self.0.clone(); - /// STUB: Handle prompts - async fn prompt(&self, arguments: acp::PromptRequest) -> Result<(), acp::Error> { - // TODO: This needs to be implemented with proper GPUI context access - // The implementation would: - // 1. Look up the Thread for this session - // 2. Convert acp::ContentBlock to agent2 message format - // 3. Call thread.send() with the converted message - // 4. Stream responses back to the client + cx.spawn(|cx| async move { + // Get thread + let thread: Entity = agent + .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; - let _session_id = arguments.session_id; - let _prompt = arguments.prompt; + // Convert prompt to message + let message = convert_prompt_to_message(params.prompt); - // STUB: Just acknowledge receipt for now - log::info!("Received prompt for session: {}", _session_id.0); + // TODO: Get model from somewhere - for now use a placeholder + log::warn!("Model selection not implemented - need to get from UI context"); - Err(acp::Error::internal_error().with_data("Prompt handling not yet implemented")) - } + // Send to thread + // thread.update(&mut cx, |thread, cx| { + // thread.send(model, message, cx) + // })?; - /// PARTIAL: Handle cancellation - async fn cancelled(&self, args: acp::CancelledNotification) -> Result<(), acp::Error> { - // Remove the session from our map - let removed = self.sessions.borrow_mut().remove(&args.session_id); - - if removed.is_some() { - // TODO: Properly clean up the Thread entity when we have GPUI context - log::info!("Session {} cancelled and removed", args.session_id.0); Ok(()) - } else { - Err(acp::Error::invalid_request() - .with_data(format!("Session {} not found", args.session_id.0))) - } + }) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + self.0.update(cx, |agent, _cx| { + agent.sessions.remove(session_id); + }); } } -// Helper functions for type conversions between acp and agent2 types +/// Convert ACP content blocks to a message string +fn convert_prompt_to_message(blocks: Vec) -> String { + let mut message = String::new(); -/// Convert acp::ContentBlock to agent2 message format -/// STUB: Needs implementation -fn convert_content_block(_block: acp::ContentBlock) -> String { - // TODO: Implement proper conversion - // This would handle: - // - Text content - // - Resource links - // - Images - // - Audio - // - Other content types - "".to_string() -} + for block in blocks { + match block { + acp::ContentBlock::Text(text) => { + message.push_str(&text.text); + } + acp::ContentBlock::ResourceLink(link) => { + message.push_str(&format!(" @{} ", link.uri)); + } + acp::ContentBlock::Image(_) => { + message.push_str(" [image] "); + } + acp::ContentBlock::Audio(_) => { + message.push_str(" [audio] "); + } + acp::ContentBlock::Resource(resource) => { + message.push_str(&format!(" [resource: {:?}] ", resource.resource)); + } + } + } -/// Convert agent2 messages to acp format for responses -/// STUB: Needs implementation -fn convert_to_acp_content(_content: &str) -> Vec { - // TODO: Implement proper conversion - vec![] + message } From a4fe8c69722c06f9494ccfcd1909e1f463c3e34d Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 08:28:37 -0600 Subject: [PATCH 21/25] Add ModelSelector capability to AgentConnection - Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods - Extend AgentConnection trait with optional model_selector() method returning Option> - Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry - Make selected_model field mandatory on Thread struct - Update Thread::new to require a default_model parameter - Update agent2 to fetch default model from registry when creating threads - Fix prompt method to use the thread's selected model directly - All methods use &mut AsyncApp for async-friendly operations --- Cargo.lock | 1 + crates/acp_thread/Cargo.toml | 1 + crates/acp_thread/src/connection.rs | 58 ++++++++++++++++++- crates/agent2/src/agent.rs | 88 ++++++++++++++++++++++++++--- crates/agent2/src/tests/mod.rs | 3 +- crates/agent2/src/thread.rs | 4 +- 6 files changed, 144 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 705ed44d38..d81effe7a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,7 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "language_model", "markdown", "project", "serde", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 011f26f364..308756b038 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -26,6 +26,7 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true language.workspace = true +language_model.workspace = true markdown.workspace = true project.workspace = true serde.workspace = true diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 929500a67b..b99e4949d8 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,13 +1,61 @@ -use std::{error::Error, fmt, path::Path, rc::Rc}; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use agent_client_protocol::{self as acp}; use anyhow::Result; use gpui::{AsyncApp, Entity, Task}; +use language_model::LanguageModel; use project::Project; use ui::App; use crate::AcpThread; +/// Trait for agents that support listing, selecting, and querying language models. +/// +/// This is an optional capability; agents indicate support via [AgentConnection::model_selector]. +pub trait ModelSelector: 'static { + /// Lists all available language models for this agent. + /// + /// # Parameters + /// - `cx`: The GPUI app context for async operations and global access. + /// + /// # Returns + /// A task resolving to the list of models or an error (e.g., if no models are configured). + fn list_models(&self, cx: &mut AsyncApp) -> Task>>>; + + /// Selects a model for a specific session (thread). + /// + /// This sets the default model for future interactions in the session. + /// If the session doesn't exist or the model is invalid, it returns an error. + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to apply the model to. + /// - `model`: The model to select (should be one from [list_models]). + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to `Ok(())` on success or an error. + fn select_model( + &self, + session_id: &acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task>; + + /// Retrieves the currently selected model for a specific session (thread). + /// + /// # Parameters + /// - `session_id`: The ID of the session (thread) to query. + /// - `cx`: The GPUI app context. + /// + /// # Returns + /// A task resolving to the selected model (always set) or an error (e.g., session not found). + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>>; +} + pub trait AgentConnection { fn new_thread( self: Rc, @@ -23,6 +71,14 @@ pub trait AgentConnection { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + + /// Returns this agent as an [Rc] if the model selection capability is supported. + /// + /// If the agent does not support model selection, returns [None]. + /// This allows sharing the selector in UI components. + fn model_selector(&self) -> Option> { + None // Default impl for agents that don't support it + } } #[derive(Debug)] diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index c1c28ad41b..f10738313e 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,6 +1,8 @@ +use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::Result; use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use language_model::{LanguageModel, LanguageModelRegistry}; use project::Project; use std::collections::HashMap; use std::path::Path; @@ -26,8 +28,67 @@ impl Agent { } /// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] pub struct AgentConnection(pub Entity); +impl ModelSelector for AgentConnection { + fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + let result = cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + }); + Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + } + + fn select_model( + &self, + session_id: &acp::SessionId, + model: Arc, + cx: &mut AsyncApp, + ) -> Task> { + let agent = self.0.clone(); + let result = agent.update(cx, |agent, cx| { + if let Some(thread) = agent.sessions.get(session_id) { + thread.update(cx, |thread, _| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow::anyhow!("Session not found")) + } + }); + Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + } + + fn selected_model( + &self, + session_id: &acp::SessionId, + cx: &mut AsyncApp, + ) -> Task>> { + let agent = self.0.clone(); + let thread_result = agent + .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned()) + .ok() + .flatten() + .ok_or_else(|| anyhow::anyhow!("Session not found")); + + match thread_result { + Ok(thread) => { + let selected = thread + .read_with(cx, |thread, _| thread.selected_model.clone()) + .unwrap_or_else(|e| panic!("Failed to read thread: {}", e)); + Task::ready(Ok(selected)) + } + Err(e) => Task::ready(Err(e)), + } + } +} + impl acp_thread::AgentConnection for AgentConnection { fn new_thread( self: Rc, @@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection { // Create Thread and store in Agent let (session_id, _thread) = agent.update(cx, |agent, cx: &mut gpui::Context| { - let thread = cx.new(|_| Thread::new(agent.templates.clone())); + // Fetch default model + let default_model = LanguageModelRegistry::read_global(cx) + .available_models(cx) + .next() + .unwrap_or_else(|| panic!("No default model available")); + + let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model)); let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); agent.sessions.insert(session_id.clone(), thread.clone()); (session_id, thread) @@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection { // Create AcpThread let acp_thread = cx.update(|cx| { - cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx)) + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx) + }) })?; Ok(acp_thread) @@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection { Task::ready(Ok(())) } + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let session_id = params.session_id.clone(); let agent = self.0.clone(); - cx.spawn(|cx| async move { + cx.spawn(async move |cx| { // Get thread let thread: Entity = agent .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? @@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection { // Convert prompt to message let message = convert_prompt_to_message(params.prompt); - // TODO: Get model from somewhere - for now use a placeholder - log::warn!("Model selection not implemented - need to get from UI context"); + // Get model using the ModelSelector capability (always available for agent2) + // Get the selected model from the thread directly + let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; // Send to thread - // thread.update(&mut cx, |thread, cx| { - // thread.send(model, message, cx) - // })?; + thread.update(cx, |thread, cx| thread.send(model, message, cx))?; Ok(()) }) diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index f3d9a35c2b..f7dc9055f6 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest { cx.executor().allow_parking(); cx.update(settings::init); let templates = Templates::new(); - let thread = cx.new(|_| Thread::new(templates)); let model = cx .update(|cx| { @@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest { }) .await; + let thread = cx.new(|_| Thread::new(templates, model.clone())); + ThreadTest { model, thread } } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index bc88cf1d95..758e940269 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -37,12 +37,13 @@ pub struct Thread { system_prompts: Vec>, tools: BTreeMap>, templates: Arc, + pub selected_model: Arc, // project: Entity, // action_log: Entity, } impl Thread { - pub fn new(templates: Arc) -> Self { + pub fn new(templates: Arc, default_model: Arc) -> Self { Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, @@ -50,6 +51,7 @@ impl Thread { running_turn: None, tools: BTreeMap::default(), templates, + selected_model: default_model, } } From 604a88f6e35697c1219c27ee92e7d7fa7169741e Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 08:45:51 -0600 Subject: [PATCH 22/25] Add comprehensive test for AgentConnection with ModelSelector - Add public session_id() method to AcpThread to enable testing - Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts - Add test_agent_connection that verifies: - Model selector is available for agent2 - Can list available models - Can create threads with default model - Can query selected model for a session - Can send prompts using the selected model - Can cancel sessions - Handles errors for invalid sessions - Remove unnecessary mut keywords from async closures --- crates/acp_thread/src/acp_thread.rs | 4 + crates/agent2/src/agent.rs | 68 +++++++-------- crates/agent2/src/tests/mod.rs | 128 +++++++++++++++++++++++++++- 3 files changed, 163 insertions(+), 37 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d10fecdb28..c42c155e89 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -656,6 +656,10 @@ impl AcpThread { &self.entries } + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { if self.waiting_for_tool_confirmation() { diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index f10738313e..abd23de375 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity); impl ModelSelector for AgentConnection { fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { - let result = cx.update(|cx| { - let registry = LanguageModelRegistry::read_global(cx); - let models = registry.available_models(cx).collect::>(); - if models.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(models) - } - }); - Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) } fn select_model( @@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection { cx: &mut AsyncApp, ) -> Task> { let agent = self.0.clone(); - let result = agent.update(cx, |agent, cx| { - if let Some(thread) = agent.sessions.get(session_id) { - thread.update(cx, |thread, _| { - thread.selected_model = model; - }); - Ok(()) - } else { - Err(anyhow::anyhow!("Session not found")) - } - }); - Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(thread) = agent.sessions.get(&session_id) { + thread.update(cx, |thread, _| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow::anyhow!("Session not found")) + } + })? + }) } fn selected_model( @@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection { cx: &mut AsyncApp, ) -> Task>> { let agent = self.0.clone(); - let thread_result = agent - .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned()) - .ok() - .flatten() - .ok_or_else(|| anyhow::anyhow!("Session not found")); - - match thread_result { - Ok(thread) => { - let selected = thread - .read_with(cx, |thread, _| thread.selected_model.clone()) - .unwrap_or_else(|e| panic!("Failed to read thread: {}", e)); - Task::ready(Ok(selected)) - } - Err(e) => Task::ready(Err(e)), - } + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let thread = agent + .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) } } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index f7dc9055f6..a628658b22 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,16 +1,19 @@ use super::*; use crate::templates::Templates; +use acp_thread::AgentConnection as _; +use agent_client_protocol as acp; use client::{Client, UserStore}; use gpui::{AppContext, Entity, TestAppContext}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry, MessageContent, StopReason, }; +use project::Project; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; -use std::{sync::Arc, time::Duration}; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; mod test_tools; use test_tools::*; @@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + // Initialize project settings + Project::init_settings(cx); + + // Use test registry with fake provider + LanguageModelRegistry::test(cx); + }); + + // Create agent and connection + let agent = cx.new(|_| Agent::new(templates.clone())); + let connection = AgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + + // Create a thread using new_thread + let cwd = Path::new("/test"); + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let selected = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + assert_eq!(selected.id().0, "fake", "should return default model"); + + // The thread was created via prompt with the default model + // We can verify it through selected_model + + // Test prompt uses the selected model + let prompt_request = acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec![acp::ContentBlock::Text(acp::TextContent { + text: "Test prompt".into(), + annotations: None, + })], + }; + + cx.update(|cx| connection.prompt(prompt_request, cx)) + .await + .expect("prompt should succeed"); + + // The prompt was sent successfully + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + + // After cancel, selected_model should fail + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await; + assert!(result.is_err(), "selected_model should fail after cancel"); + + // Test error case: invalid session + let invalid_session = acp::SessionId("invalid".into()); + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&invalid_session, &mut async_cx) + }) + .await; + assert!(result.is_err(), "should fail for invalid session"); + if let Err(e) = result { + assert!( + e.to_string().contains("Session not found"), + "should have correct error message" + ); + } +} + /// Filters out the stop events for asserting against in tests fn stop_events( result_events: Vec>, From 4f2d6a9ea97ba9ba9c460166e4f4ee7dc08e7b64 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 08:59:16 -0600 Subject: [PATCH 23/25] Rename Agent to NativeAgent and AgentConnection to NativeAgentConnection - Renamed Agent struct to NativeAgent to better reflect its native implementation - Renamed AgentConnection to NativeAgentConnection for consistency - Updated all references and implementations - Bumped agent-client-protocol version to 0.0.14 --- Cargo.lock | 2 +- crates/agent2/src/agent.rs | 12 ++++++------ crates/agent2/src/tests/mod.rs | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d81effe7a2..494e8cd744 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,7 +138,7 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.13" +version = "0.0.14" dependencies = [ "anyhow", "futures 0.3.31", diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index abd23de375..60c154971b 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -11,14 +11,14 @@ use std::sync::Arc; use crate::{templates::Templates, Thread}; -pub struct Agent { +pub struct NativeAgent { /// Session ID -> Thread entity mapping sessions: HashMap>, /// Shared templates for all threads templates: Arc, } -impl Agent { +impl NativeAgent { pub fn new(templates: Arc) -> Self { Self { sessions: HashMap::new(), @@ -29,9 +29,9 @@ impl Agent { /// Wrapper struct that implements the AgentConnection trait #[derive(Clone)] -pub struct AgentConnection(pub Entity); +pub struct NativeAgentConnection(pub Entity); -impl ModelSelector for AgentConnection { +impl ModelSelector for NativeAgentConnection { fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { cx.spawn(async move |cx| { cx.update(|cx| { @@ -85,7 +85,7 @@ impl ModelSelector for AgentConnection { } } -impl acp_thread::AgentConnection for AgentConnection { +impl acp_thread::AgentConnection for NativeAgentConnection { fn new_thread( self: Rc, project: Entity, @@ -98,7 +98,7 @@ impl acp_thread::AgentConnection for AgentConnection { cx.spawn(async move |cx| { // Create Thread and store in Agent let (session_id, _thread) = - agent.update(cx, |agent, cx: &mut gpui::Context| { + agent.update(cx, |agent, cx: &mut gpui::Context| { // Fetch default model let default_model = LanguageModelRegistry::read_global(cx) .available_models(cx) diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index a628658b22..ced3c239f0 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -216,8 +216,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) { }); // Create agent and connection - let agent = cx.new(|_| Agent::new(templates.clone())); - let connection = AgentConnection(agent.clone()); + let agent = cx.new(|_| NativeAgent::new(templates.clone())); + let connection = NativeAgentConnection(agent.clone()); // Test model_selector returns Some let selector_opt = connection.model_selector(); From bc1f861d3fe2f66102f1a9f5562f32cf716c81c7 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 09:17:26 -0600 Subject: [PATCH 24/25] Add Native Agent to UI and implement NativeAgentServer - Created NativeAgentServer that implements AgentServer trait - Added NativeAgent to ExternalAgent enum - Added Native Agent option to both the + menu and empty state view - Added necessary dependencies (agent_servers, ui) to agent2 crate - Added agent2 dependency to agent_ui crate - Temporarily removed feature flag check for testing --- Cargo.lock | 3 + crates/agent2/Cargo.toml | 2 + crates/agent2/src/agent2.rs | 2 + crates/agent2/src/native_agent_server.rs | 51 ++++++ crates/agent_ui/Cargo.toml | 1 + crates/agent_ui/src/agent_panel.rs | 216 +++++++++++++---------- crates/agent_ui/src/agent_ui.rs | 2 + 7 files changed, 188 insertions(+), 89 deletions(-) create mode 100644 crates/agent2/src/native_agent_server.rs diff --git a/Cargo.lock b/Cargo.lock index 494e8cd744..6daa01dca3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,7 @@ version = "0.1.0" dependencies = [ "acp_thread", "agent-client-protocol", + "agent_servers", "anyhow", "assistant_tool", "assistant_tools", @@ -182,6 +183,7 @@ dependencies = [ "settings", "smol", "thiserror 2.0.12", + "ui", "util", "uuid", "worktree", @@ -250,6 +252,7 @@ dependencies = [ "acp_thread", "agent", "agent-client-protocol", + "agent2", "agent_servers", "agent_settings", "ai_onboarding", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 72f5f14008..b2e83e2252 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] acp_thread.workspace = true agent-client-protocol.workspace = true +agent_servers.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -36,6 +37,7 @@ serde_json.workspace = true settings.workspace = true smol.workspace = true thiserror.workspace = true +ui.workspace = true util.workspace = true uuid.workspace = true worktree.workspace = true diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 66ed32eccd..aa665fe313 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,4 +1,5 @@ mod agent; +mod native_agent_server; mod prompts; mod templates; mod thread; @@ -8,4 +9,5 @@ mod tools; mod tests; pub use agent::*; +pub use native_agent_server::NativeAgentServer; pub use thread::*; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs new file mode 100644 index 0000000000..b71fa455b5 --- /dev/null +++ b/crates/agent2/src/native_agent_server.rs @@ -0,0 +1,51 @@ +use std::path::Path; +use std::rc::Rc; + +use agent_servers::AgentServer; +use anyhow::Result; +use gpui::{App, AppContext, Entity, Task}; +use project::Project; + +use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; + +#[derive(Clone)] +pub struct NativeAgentServer; + +impl AgentServer for NativeAgentServer { + fn name(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_headline(&self) -> &'static str { + "Native Agent" + } + + fn empty_state_message(&self) -> &'static str { + "How can I help you today?" + } + + fn logo(&self) -> ui::IconName { + // Using the ZedAssistant icon as it's the native built-in agent + ui::IconName::ZedAssistant + } + + fn connect( + &self, + _root_dir: &Path, + _project: &Entity, + cx: &mut App, + ) -> Task>> { + cx.spawn(async move |cx| { + // Create templates (you might want to load these from files or resources) + let templates = Templates::new(); + + // Create the native agent + let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?; + + // Create the connection wrapper + let connection = NativeAgentConnection(agent); + + Ok(Rc::new(connection) as Rc) + }) + } +} diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 95fd2b1757..c145df0eae 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"] acp_thread.workspace = true agent-client-protocol.workspace = true agent.workspace = true +agent2.workspace = true agent_servers.workspace = true agent_settings.workspace = true ai_onboarding.workspace = true diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index a09c669769..5813b48f6e 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1954,40 +1954,54 @@ impl AgentPanel { this } }) - .when(cx.has_flag::(), |this| { - this.separator() - .header("External Agents") - .item( - ContextMenuEntry::new("New Gemini Thread") - .icon(IconName::AiGemini) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some(crate::ExternalAgent::Gemini), - } - .boxed_clone(), - cx, - ); - }), - ) - .item( - ContextMenuEntry::new("New Claude Code Thread") - .icon(IconName::AiClaude) - .icon_color(Color::Muted) - .handler(move |window, cx| { - window.dispatch_action( - NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - } - .boxed_clone(), - cx, - ); - }), - ) - }); + // Temporarily removed feature flag check for testing + // .when(cx.has_flag::(), |this| { + // this + .separator() + .header("External Agents") + .item( + ContextMenuEntry::new("New Gemini Thread") + .icon(IconName::AiGemini) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Claude Code Thread") + .icon(IconName::AiClaude) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::ClaudeCode), + } + .boxed_clone(), + cx, + ); + }), + ) + .item( + ContextMenuEntry::new("New Native Agent Thread") + .icon(IconName::ZedAssistant) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::NativeAgent), + } + .boxed_clone(), + cx, + ); + }), + ); + // }); menu })) } @@ -2594,63 +2608,87 @@ impl AgentPanel { ), ), ) - .when(cx.has_flag::(), |this| { - this.child( - h_flex() - .w_full() - .gap_2() - .child( - NewThreadButton::new( - "new-gemini-thread-btn", - "New Gemini Thread", - IconName::AiGemini, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::Gemini, - ), - }), - cx, - ) - }, - ), + // Temporarily removed feature flag check for testing + // .when(cx.has_flag::(), |this| { + // this + .child( + h_flex() + .w_full() + .gap_2() + .child( + NewThreadButton::new( + "new-gemini-thread-btn", + "New Gemini Thread", + IconName::AiGemini, ) - .child( - NewThreadButton::new( - "new-claude-thread-btn", - "New Claude Code Thread", - IconName::AiClaude, - ) - // .keybinding(KeyBinding::for_action_in( - // &OpenHistory, - // &self.focus_handle(cx), - // window, - // cx, - // )) - .on_click( - |window, cx| { - window.dispatch_action( - Box::new(NewExternalAgentThread { - agent: Some( - crate::ExternalAgent::ClaudeCode, - ), - }), - cx, - ) - }, - ), + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Gemini), + }), + cx, + ) + }, ), - ) - }), + ) + .child( + NewThreadButton::new( + "new-claude-thread-btn", + "New Claude Code Thread", + IconName::AiClaude, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::ClaudeCode, + ), + }), + cx, + ) + }, + ), + ) + .child( + NewThreadButton::new( + "new-native-agent-thread-btn", + "New Native Agent Thread", + IconName::ZedAssistant, + ) + // .keybinding(KeyBinding::for_action_in( + // &OpenHistory, + // &self.focus_handle(cx), + // window, + // cx, + // )) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::NativeAgent, + ), + }), + cx, + ) + }, + ), + ), + ), // }) ) .when_some(configuration_error.as_ref(), |this, err| { this.child(self.render_configuration_error(err, &focus_handle, window, cx)) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index c5574c2371..adbecb75cb 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -150,6 +150,7 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, + NativeAgent, } impl ExternalAgent { @@ -157,6 +158,7 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer), } } } From f81993574ea0b0d96fc85c44d0139a8701c4f817 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 09:58:33 -0600 Subject: [PATCH 25/25] Connect Native Agent responses to UI display User-visible improvements: - Native Agent now shows AI responses in the chat interface - Uses configured default model from settings instead of random selection - Streams responses in real-time as the model generates them Technical changes: - Implemented response stream forwarding from Thread to AcpThread - Created Session struct to manage Thread and AcpThread together - Added proper SessionUpdate handling for text chunks and tool calls - Fixed model selection to use LanguageModelRegistry's default - Added comprehensive logging for debugging model interactions - Removed unused cwd parameter - native agent captures context differently than external agents --- crates/acp_thread/src/connection.rs | 2 +- crates/agent2/src/agent.rs | 242 ++++++++++++++++++++--- crates/agent2/src/native_agent_server.rs | 7 + crates/agent2/src/thread.rs | 111 ++++++++--- 4 files changed, 307 insertions(+), 55 deletions(-) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index b99e4949d8..9659433edb 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -36,7 +36,7 @@ pub trait ModelSelector: 'static { /// A task resolving to `Ok(())` on success or an error. fn select_model( &self, - session_id: &acp::SessionId, + session_id: acp::SessionId, model: Arc, cx: &mut AsyncApp, ) -> Task>; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 60c154971b..0491f7c81a 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,6 +1,6 @@ use acp_thread::ModelSelector; use agent_client_protocol as acp; -use anyhow::Result; +use anyhow::{anyhow, Result}; use gpui::{App, AppContext, AsyncApp, Entity, Task}; use language_model::{LanguageModel, LanguageModelRegistry}; use project::Project; @@ -11,15 +11,25 @@ use std::sync::Arc; use crate::{templates::Templates, Thread}; +/// Holds both the internal Thread and the AcpThread for a session +#[derive(Clone)] +struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: Entity, +} + pub struct NativeAgent { - /// Session ID -> Thread entity mapping - sessions: HashMap>, + /// Session ID -> Session mapping + sessions: HashMap, /// Shared templates for all threads templates: Arc, } impl NativeAgent { pub fn new(templates: Arc) -> Self { + log::info!("Creating new NativeAgent"); Self { sessions: HashMap::new(), templates, @@ -33,10 +43,12 @@ pub struct NativeAgentConnection(pub Entity); impl ModelSelector for NativeAgentConnection { fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { + log::debug!("NativeAgentConnection::list_models called"); cx.spawn(async move |cx| { cx.update(|cx| { let registry = LanguageModelRegistry::read_global(cx); let models = registry.available_models(cx).collect::>(); + log::info!("Found {} available models", models.len()); if models.is_empty() { Err(anyhow::anyhow!("No models available")) } else { @@ -48,21 +60,26 @@ impl ModelSelector for NativeAgentConnection { fn select_model( &self, - session_id: &acp::SessionId, + session_id: acp::SessionId, model: Arc, cx: &mut AsyncApp, ) -> Task> { + log::info!( + "Setting model for session {}: {:?}", + session_id, + model.name() + ); let agent = self.0.clone(); - let session_id = session_id.clone(); + cx.spawn(async move |cx| { agent.update(cx, |agent, cx| { - if let Some(thread) = agent.sessions.get(&session_id) { - thread.update(cx, |thread, _| { + if let Some(session) = agent.sessions.get(&session_id) { + session.thread.update(cx, |thread, _cx| { thread.selected_model = model; }); Ok(()) } else { - Err(anyhow::anyhow!("Session not found")) + Err(anyhow!("Session not found")) } })? }) @@ -76,10 +93,12 @@ impl ModelSelector for NativeAgentConnection { let agent = self.0.clone(); let session_id = session_id.clone(); cx.spawn(async move |cx| { - let thread = agent + let session = agent .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? .ok_or_else(|| anyhow::anyhow!("Session not found"))?; - let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + let selected = session + .thread + .read_with(cx, |thread, _| thread.selected_model.clone())?; Ok(selected) }) } @@ -92,32 +111,64 @@ impl acp_thread::AgentConnection for NativeAgentConnection { cwd: &Path, cx: &mut AsyncApp, ) -> Task>> { - let _cwd = cwd.to_owned(); let agent = self.0.clone(); + log::info!("Creating new thread for project at: {:?}", cwd); cx.spawn(async move |cx| { - // Create Thread and store in Agent - let (session_id, _thread) = - agent.update(cx, |agent, cx: &mut gpui::Context| { - // Fetch default model - let default_model = LanguageModelRegistry::read_global(cx) - .available_models(cx) - .next() - .unwrap_or_else(|| panic!("No default model available")); + log::debug!("Starting thread creation in async context"); + // Create Thread + let (session_id, thread) = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry + .default_model() + .map(|configured| { + log::info!( + "Using configured default model: {:?} from provider: {:?}", + configured.model.name(), + configured.provider.name() + ); + configured.model + }) + .ok_or_else(|| { + log::warn!("No default model configured in settings"); + anyhow!("No default model configured. Please configure a default model in settings.") + })?; let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model)); + + // Generate session ID let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); - agent.sessions.insert(session_id.clone(), thread.clone()); - (session_id, thread) - })?; + log::info!("Created session with ID: {}", session_id); + Ok((session_id, thread)) + }, + )??; // Create AcpThread let acp_thread = cx.update(|cx| { cx.new(|cx| { - acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx) + acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx) }) })?; + // Store the session + agent.update(cx, |agent, _cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.clone(), + }, + ); + })?; + Ok(acp_thread) }) } @@ -137,28 +188,155 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task> { let session_id = params.session_id.clone(); let agent = self.0.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); cx.spawn(async move |cx| { - // Get thread - let thread: Entity = agent - .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? - .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + // Get session + let session = agent + .read_with(cx, |agent, _| { + agent.sessions.get(&session_id).map(|s| Session { + thread: s.thread.clone(), + acp_thread: s.acp_thread.clone(), + }) + })? + .ok_or_else(|| { + log::error!("Session not found: {}", session_id); + anyhow::anyhow!("Session not found") + })?; + log::debug!("Found session for: {}", session_id); // Convert prompt to message let message = convert_prompt_to_message(params.prompt); + log::info!("Converted prompt to message: {} chars", message.len()); + log::debug!("Message content: {}", message); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly - let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + let model = session + .thread + .read_with(cx, |thread, _| thread.selected_model.clone())?; // Send to thread - thread.update(cx, |thread, cx| thread.send(model, message, cx))?; + log::info!("Sending message to thread with model: {:?}", model.name()); + let response_stream = session + .thread + .update(cx, |thread, cx| thread.send(model, message, cx))?; + // Handle response stream and forward to session.acp_thread + let acp_thread = session.acp_thread.clone(); + cx.spawn(async move |cx| { + use futures::StreamExt; + use language_model::LanguageModelCompletionEvent; + + let mut response_stream = response_stream; + + while let Some(result) = response_stream.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + LanguageModelCompletionEvent::Text(text) => { + // Send text chunk as agent message + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: acp::ContentBlock::Text( + acp::TextContent { + text: text.into(), + annotations: None, + }, + ), + }, + cx, + ) + })??; + } + LanguageModelCompletionEvent::ToolUse(tool_use) => { + // Convert LanguageModelToolUse to ACP ToolCall + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + label: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(tool_use.input), + }), + cx, + ) + })??; + } + LanguageModelCompletionEvent::StartMessage { .. } => { + log::debug!("Started new assistant message"); + } + LanguageModelCompletionEvent::UsageUpdate(usage) => { + log::debug!("Token usage update: {:?}", usage); + } + LanguageModelCompletionEvent::Thinking { text, .. } => { + // Send thinking text as agent thought chunk + acp_thread.update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AgentThoughtChunk { + content: acp::ContentBlock::Text( + acp::TextContent { + text: text.into(), + annotations: None, + }, + ), + }, + cx, + ) + })??; + } + LanguageModelCompletionEvent::StatusUpdate(status) => { + log::trace!("Status update: {:?}", status); + } + LanguageModelCompletionEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + } + LanguageModelCompletionEvent::RedactedThinking { .. } => { + log::trace!("Redacted thinking event"); + } + LanguageModelCompletionEvent::ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + log::error!( + "Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}", + tool_name, + id, + json_parse_error, + raw_input + ); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + // TODO: Consider sending an error message to the UI + break; + } + } + } + + log::info!("Response stream completed"); + anyhow::Ok(()) + }) + .detach(); + + log::info!("Successfully sent prompt to thread and started response handler"); Ok(()) }) } fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling session: {}", session_id); self.0.update(cx, |agent, _cx| { agent.sessions.remove(session_id); }); @@ -167,23 +345,29 @@ impl acp_thread::AgentConnection for NativeAgentConnection { /// Convert ACP content blocks to a message string fn convert_prompt_to_message(blocks: Vec) -> String { + log::debug!("Converting {} content blocks to message", blocks.len()); let mut message = String::new(); for block in blocks { match block { acp::ContentBlock::Text(text) => { + log::trace!("Processing text block: {} chars", text.text.len()); message.push_str(&text.text); } acp::ContentBlock::ResourceLink(link) => { + log::trace!("Processing resource link: {}", link.uri); message.push_str(&format!(" @{} ", link.uri)); } acp::ContentBlock::Image(_) => { + log::trace!("Processing image block"); message.push_str(" [image] "); } acp::ContentBlock::Audio(_) => { + log::trace!("Processing audio block"); message.push_str(" [audio] "); } acp::ContentBlock::Resource(resource) => { + log::trace!("Processing resource block: {:?}", resource.resource); message.push_str(&format!(" [resource: {:?}] ", resource.resource)); } } diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index b71fa455b5..aafe70a8a2 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -35,15 +35,22 @@ impl AgentServer for NativeAgentServer { _project: &Entity, cx: &mut App, ) -> Task>> { + log::info!( + "NativeAgentServer::connect called for path: {:?}", + _root_dir + ); cx.spawn(async move |cx| { + log::debug!("Creating templates for native agent"); // Create templates (you might want to load these from files or resources) let templates = Templates::new(); // Create the native agent + log::debug!("Creating native agent entity"); let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?; // Create the connection wrapper let connection = NativeAgentConnection(agent); + log::info!("NativeAgentServer connection established successfully"); Ok(Rc::new(connection) as Rc) }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 758e940269..4bc2932aeb 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -9,6 +9,7 @@ use language_model::{ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason, }; +use log; use schemars::{JsonSchema, Schema}; use serde::Deserialize; use smol::stream::StreamExt; @@ -38,7 +39,6 @@ pub struct Thread { tools: BTreeMap>, templates: Arc, pub selected_model: Arc, - // project: Entity, // action_log: Entity, } @@ -80,22 +80,36 @@ impl Thread { content: impl Into, cx: &mut Context, ) -> mpsc::UnboundedReceiver> { + let content = content.into(); + log::info!("Thread::send called with model: {:?}", model.name()); + log::debug!("Thread::send content: {:?}", content); + cx.notify(); let (events_tx, events_rx) = mpsc::unbounded::>(); let system_message = self.build_system_message(cx); + log::debug!( + "System messages count: {}", + if system_message.is_some() { 1 } else { 0 } + ); self.messages.extend(system_message); self.messages.push(AgentMessage { role: Role::User, - content: vec![content.into()], + content: vec![content], }); + log::info!("Total messages in thread: {}", self.messages.len()); self.running_turn = Some(cx.spawn(async move |thread, cx| { + log::info!("Starting agent turn execution"); let turn_result = async { // Perform one request, then keep looping if the model makes tool calls. let mut completion_intent = CompletionIntent::UserPrompt; loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); let request = thread.update(cx, |thread, cx| { thread.build_completion_request(completion_intent, cx) })?; @@ -106,11 +120,14 @@ impl Thread { // ); // Stream events, appending to messages and collecting up tool uses. + log::info!("Calling model.stream_completion"); let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); let mut tool_uses = Vec::new(); while let Some(event) = events.next().await { match event { Ok(event) => { + log::trace!("Received completion event: {:?}", event); thread .update(cx, |thread, cx| { tool_uses.extend(thread.handle_streamed_completion_event( @@ -122,6 +139,7 @@ impl Thread { .ok(); } Err(error) => { + log::error!("Error in completion stream: {:?}", error); events_tx.unbounded_send(Err(error)).ok(); break; } @@ -130,13 +148,16 @@ impl Thread { // If there are no tool uses, the turn is done. if tool_uses.is_empty() { + log::info!("No tool uses found, completing turn"); break; } + log::info!("Found {} tool uses to execute", tool_uses.len()); // If there are tool uses, wait for their results to be // computed, then send them together in a single message on // the next loop iteration. let tool_results = future::join_all(tool_uses).await; + log::debug!("Tool execution completed, {} results", tool_results.len()); thread .update(cx, |thread, _cx| { thread.messages.push(AgentMessage { @@ -156,13 +177,17 @@ impl Thread { .await; if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); events_tx.unbounded_send(Err(error)).ok(); + } else { + log::info!("Turn execution completed successfully"); } })); events_rx } pub fn build_system_message(&mut self, cx: &App) -> Option { + log::debug!("Building system message"); let mut system_message = AgentMessage { role: Role::System, content: Vec::new(), @@ -176,7 +201,9 @@ impl Thread { } } - (!system_message.content.is_empty()).then_some(system_message) + let result = (!system_message.content.is_empty()).then_some(system_message); + log::debug!("System message built: {}", result.is_some()); + result } /// A helper method that's called on every streamed completion event. @@ -188,6 +215,7 @@ impl Thread { events_tx: mpsc::UnboundedSender>, cx: &mut Context, ) -> Option> { + log::trace!("Handling streamed completion event: {:?}", event); use LanguageModelCompletionEvent::*; events_tx.unbounded_send(Ok(event.clone())).ok(); @@ -329,41 +357,74 @@ impl Thread { completion_intent: CompletionIntent, cx: &mut App, ) -> LanguageModelRequest { - LanguageModelRequest { + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); + + let messages = self.build_request_messages(); + log::info!("Request will include {} messages", messages.len()); + + let tools: Vec = self + .tools + .values() + .filter_map(|tool| { + let tool_name = tool.name().to_string(); + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name, + description: tool.description(cx).to_string(), + input_schema: tool + .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .log_err()?, + }) + }) + .collect(); + + log::info!("Request includes {} tools", tools.len()); + + let request = LanguageModelRequest { thread_id: None, prompt_id: None, intent: Some(completion_intent), mode: Some(self.completion_mode), - messages: self.build_request_messages(), - tools: self - .tools - .values() - .filter_map(|tool| { - Some(LanguageModelRequestTool { - name: tool.name().to_string(), - description: tool.description(cx).to_string(), - input_schema: tool - .input_schema(LanguageModelToolSchemaFormat::JsonSchema) - .log_err()?, - }) - }) - .collect(), + messages, + tools, tool_choice: None, stop: Vec::new(), temperature: None, thinking_allowed: false, - } + }; + + log::debug!("Completion request built successfully"); + request } fn build_request_messages(&self) -> Vec { - self.messages + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() + ); + let messages = self + .messages .iter() - .map(|message| LanguageModelRequestMessage { - role: message.role, - content: message.content.clone(), - cache: false, + .map(|message| { + log::trace!( + " - {} message with {} content items", + match message.role { + Role::System => "System", + Role::User => "User", + Role::Assistant => "Assistant", + }, + message.content.len() + ); + LanguageModelRequestMessage { + role: message.role, + content: message.content.clone(), + cache: false, + } }) - .collect() + .collect(); + messages } }