From c2fc70eef7454a49de9f87d38f34b65392752745 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 28 Jul 2025 11:14:10 -0300 Subject: [PATCH] ACP over MCP server impl (#35196) Release Notes: - N/A --------- Co-authored-by: Ben Brandt --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 299 +++++++++++++++----- crates/acp_thread/src/connection.rs | 2 +- crates/acp_thread/src/old_acp_support.rs | 58 ++-- crates/agent_servers/Cargo.toml | 1 + crates/agent_servers/src/agent_servers.rs | 3 + crates/agent_servers/src/claude.rs | 20 +- crates/agent_servers/src/claude/tools.rs | 1 + crates/agent_servers/src/codex.rs | 317 ++++++++++++++++++++++ crates/agent_servers/src/e2e_tests.rs | 3 + crates/agent_servers/src/mcp_server.rs | 201 ++++++++++++++ crates/agent_servers/src/settings.rs | 11 +- crates/agent_ui/src/acp/thread_view.rs | 15 +- crates/agent_ui/src/agent_panel.rs | 33 +++ crates/agent_ui/src/agent_ui.rs | 2 + crates/context_server/src/client.rs | 40 ++- crates/context_server/src/listener.rs | 2 +- crates/context_server/src/protocol.rs | 9 +- crates/context_server/src/types.rs | 13 + 20 files changed, 899 insertions(+), 137 deletions(-) create mode 100644 crates/agent_servers/src/codex.rs create mode 100644 crates/agent_servers/src/mcp_server.rs diff --git a/Cargo.lock b/Cargo.lock index 5f746a02fa..8d9a622655 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.10" +version = "0.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fb7f39671e02f8a1aeb625652feae40b6fc2597baaa97e028a98863477aecbd" +checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b" dependencies = [ "schemars", "serde", diff --git a/Cargo.toml b/Cargo.toml index 39b60dda01..16ace7dee0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.10" +agent-client-protocol = "0.0.11" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 3c6c21205f..d572992c54 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -166,6 +166,7 @@ pub struct ToolCall { pub content: Vec, pub status: ToolCallStatus, pub locations: Vec, + pub raw_input: Option, } impl ToolCall { @@ -193,6 +194,50 @@ impl ToolCall { .collect(), locations: tool_call.locations, status, + raw_input: tool_call.raw_input, + } + } + + fn update( + &mut self, + fields: acp::ToolCallUpdateFields, + language_registry: Arc, + cx: &mut App, + ) { + let acp::ToolCallUpdateFields { + kind, + status, + label, + content, + locations, + raw_input, + } = fields; + + if let Some(kind) = kind { + self.kind = kind; + } + + if let Some(status) = status { + self.status = ToolCallStatus::Allowed { status }; + } + + if let Some(label) = label { + self.label = cx.new(|cx| Markdown::new_text(label.into(), cx)); + } + + if let Some(content) = content { + self.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx)) + .collect(); + } + + if let Some(locations) = locations { + self.locations = locations; + } + + if let Some(raw_input) = raw_input { + self.raw_input = Some(raw_input); } } @@ -238,6 +283,7 @@ impl Display for ToolCallStatus { match self { ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation", ToolCallStatus::Allowed { status } => match status { + acp::ToolCallStatus::Pending => "Pending", acp::ToolCallStatus::InProgress => "In Progress", acp::ToolCallStatus::Completed => "Completed", acp::ToolCallStatus::Failed => "Failed", @@ -345,7 +391,7 @@ impl ToolCallContent { cx: &mut App, ) -> Self { match content { - acp::ToolCallContent::ContentBlock { content } => Self::ContentBlock { + acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { @@ -630,12 +676,50 @@ impl AcpThread { false } - pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { - self.entries.push(entry); - cx.emit(AcpThreadEvent::NewEntry); + pub fn handle_session_update( + &mut self, + update: acp::SessionUpdate, + cx: &mut Context, + ) -> Result<()> { + match update { + acp::SessionUpdate::UserMessage(content_block) => { + self.push_user_content_block(content_block, cx); + } + acp::SessionUpdate::AgentMessageChunk(content_block) => { + self.push_assistant_content_block(content_block, false, cx); + } + acp::SessionUpdate::AgentThoughtChunk(content_block) => { + self.push_assistant_content_block(content_block, true, cx); + } + acp::SessionUpdate::ToolCall(tool_call) => { + self.upsert_tool_call(tool_call, cx); + } + acp::SessionUpdate::ToolCallUpdate(tool_call_update) => { + self.update_tool_call(tool_call_update, cx)?; + } + acp::SessionUpdate::Plan(plan) => { + self.update_plan(plan, cx); + } + } + Ok(()) } - pub fn push_assistant_chunk( + pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + let language_registry = self.project.read(cx).languages().clone(); + let entries_len = self.entries.len(); + + if let Some(last_entry) = self.entries.last_mut() + && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + { + content.append(chunk, &language_registry, cx); + cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + } else { + let content = ContentBlock::new(chunk, &language_registry, cx); + self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + } + } + + pub fn push_assistant_content_block( &mut self, chunk: acp::ContentBlock, is_thought: bool, @@ -678,23 +762,22 @@ impl AcpThread { } } + fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context) { + self.entries.push(entry); + cx.emit(AcpThreadEvent::NewEntry); + } + pub fn update_tool_call( &mut self, - id: acp::ToolCallId, - status: acp::ToolCallStatus, - content: Option>, + update: acp::ToolCallUpdate, cx: &mut Context, ) -> Result<()> { let languages = self.project.read(cx).languages().clone(); - let (ix, current_call) = self.tool_call_mut(&id).context("Tool call not found")?; - if let Some(content) = content { - current_call.content = content - .into_iter() - .map(|chunk| ToolCallContent::from_acp(chunk, languages.clone(), cx)) - .collect(); - } - current_call.status = ToolCallStatus::Allowed { status }; + let (ix, current_call) = self + .tool_call_mut(&update.id) + .context("Tool call not found")?; + current_call.update(update.fields, languages, cx); cx.emit(AcpThreadEvent::EntryUpdated(ix)); @@ -751,6 +834,37 @@ impl AcpThread { }) } + pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { + self.project.update(cx, |project, cx| { + let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { + return; + }; + let buffer = project.open_buffer(path, cx); + cx.spawn(async move |project, cx| { + let buffer = buffer.await?; + + project.update(cx, |project, cx| { + let position = if let Some(line) = location.line { + let snapshot = buffer.read(cx).snapshot(); + let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); + snapshot.anchor_before(point) + } else { + Anchor::MIN + }; + + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position, + }), + cx, + ); + }) + }) + .detach_and_log_err(cx); + }); + } + pub fn request_tool_call_permission( &mut self, tool_call: acp::ToolCall, @@ -801,6 +915,25 @@ impl AcpThread { cx.emit(AcpThreadEvent::EntryUpdated(ix)); } + /// Returns true if the last turn is awaiting tool authorization + pub fn waiting_for_tool_confirmation(&self) -> bool { + for entry in self.entries.iter().rev() { + match &entry { + AgentThreadEntry::ToolCall(call) => match call.status { + ToolCallStatus::WaitingForConfirmation { .. } => return true, + ToolCallStatus::Allowed { .. } + | ToolCallStatus::Rejected + | ToolCallStatus::Canceled => continue, + }, + AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { + // Reached the beginning of the turn + return false; + } + } + } + false + } + pub fn plan(&self) -> &Plan { &self.plan } @@ -824,56 +957,6 @@ impl AcpThread { cx.notify(); } - pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context) { - self.project.update(cx, |project, cx| { - let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else { - return; - }; - let buffer = project.open_buffer(path, cx); - cx.spawn(async move |project, cx| { - let buffer = buffer.await?; - - project.update(cx, |project, cx| { - let position = if let Some(line) = location.line { - let snapshot = buffer.read(cx).snapshot(); - let point = snapshot.clip_point(Point::new(line, 0), Bias::Left); - snapshot.anchor_before(point) - } else { - Anchor::MIN - }; - - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position, - }), - cx, - ); - }) - }) - .detach_and_log_err(cx); - }); - } - - /// Returns true if the last turn is awaiting tool authorization - pub fn waiting_for_tool_confirmation(&self) -> bool { - for entry in self.entries.iter().rev() { - match &entry { - AgentThreadEntry::ToolCall(call) => match call.status { - ToolCallStatus::WaitingForConfirmation { .. } => return true, - ToolCallStatus::Allowed { .. } - | ToolCallStatus::Rejected - | ToolCallStatus::Canceled => continue, - }, - AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => { - // Reached the beginning of the turn - return false; - } - } - } - false - } - pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { self.connection.authenticate(cx) } @@ -919,7 +1002,7 @@ impl AcpThread { let result = this .update(cx, |this, cx| { this.connection.prompt( - acp::PromptToolArguments { + acp::PromptArguments { prompt: message, session_id: this.session_id.clone(), }, @@ -1148,7 +1231,87 @@ mod tests { } #[gpui::test] - async fn test_thinking_concatenation(cx: &mut TestAppContext) { + async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (thread, _fake_server) = fake_acp_thread(project, cx); + + // Test creating a new user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Hello, ".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test appending to existing user message + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "world!".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 1); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); + } else { + panic!("Expected UserMessage"); + } + }); + + // Test creating new user message after assistant message + thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "Assistant response".to_string(), + }), + false, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.push_user_content_block( + acp::ContentBlock::Text(acp::TextContent { + annotations: None, + text: "New user message".to_string(), + }), + cx, + ); + }); + + thread.update(cx, |thread, cx| { + assert_eq!(thread.entries.len(), 3); + if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.content.to_markdown(cx), "New user message"); + } else { + panic!("Expected UserMessage at index 2"); + } + }); + } + + #[gpui::test] + async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) { init_test(cx); let fs = FakeFs::new(cx.executor()); diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index fde167da5f..5b25b71863 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -20,7 +20,7 @@ pub trait AgentConnection { fn authenticate(&self, cx: &mut App) -> Task>; - fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task>; + fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); } diff --git a/crates/acp_thread/src/old_acp_support.rs b/crates/acp_thread/src/old_acp_support.rs index 316a5bcf25..44cd00348f 100644 --- a/crates/acp_thread/src/old_acp_support.rs +++ b/crates/acp_thread/src/old_acp_support.rs @@ -8,7 +8,7 @@ use project::Project; use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc}; use ui::App; -use crate::{AcpThread, AcpThreadEvent, AgentConnection, ToolCallContent, ToolCallStatus}; +use crate::{AcpThread, AgentConnection}; #[derive(Clone)] pub struct OldAcpClientDelegate { @@ -40,10 +40,10 @@ impl acp_old::Client for OldAcpClientDelegate { .borrow() .update(cx, |thread, cx| match params.chunk { acp_old::AssistantMessageChunk::Text { text } => { - thread.push_assistant_chunk(text.into(), false, cx) + thread.push_assistant_content_block(text.into(), false, cx) } acp_old::AssistantMessageChunk::Thought { thought } => { - thread.push_assistant_chunk(thought.into(), true, cx) + thread.push_assistant_content_block(thought.into(), true, cx) } }) .ok(); @@ -182,31 +182,23 @@ impl acp_old::Client for OldAcpClientDelegate { cx.update(|cx| { self.thread.borrow().update(cx, |thread, cx| { - let languages = thread.project.read(cx).languages().clone(); - - if let Some((ix, tool_call)) = thread - .tool_call_mut(&acp::ToolCallId(request.tool_call_id.0.to_string().into())) - { - tool_call.status = ToolCallStatus::Allowed { - status: into_new_tool_call_status(request.status), - }; - tool_call.content = request - .content - .into_iter() - .map(|content| { - ToolCallContent::from_acp( - into_new_tool_call_content(content), - languages.clone(), - cx, - ) - }) - .collect(); - - cx.emit(AcpThreadEvent::EntryUpdated(ix)); - anyhow::Ok(()) - } else { - anyhow::bail!("Tool call not found") - } + thread.update_tool_call( + acp::ToolCallUpdate { + id: acp::ToolCallId(request.tool_call_id.0.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(into_new_tool_call_status(request.status)), + content: Some( + request + .content + .into_iter() + .map(into_new_tool_call_content) + .collect::>(), + ), + ..Default::default() + }, + }, + cx, + ) }) })? .context("Failed to update thread")??; @@ -285,6 +277,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) .into_iter() .map(into_new_tool_call_location) .collect(), + raw_input: None, } } @@ -311,12 +304,7 @@ fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallSt fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent { match content { - acp_old::ToolCallContent::Markdown { markdown } => acp::ToolCallContent::ContentBlock { - content: acp::ContentBlock::Text(acp::TextContent { - annotations: None, - text: markdown, - }), - }, + acp_old::ToolCallContent::Markdown { markdown } => markdown.into(), acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff { diff: into_new_diff(diff), }, @@ -423,7 +411,7 @@ impl AgentConnection for OldAcpAgentConnection { }) } - fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { let chunks = params .prompt .into_iter() diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 023799bc6c..dcffb05bc0 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -41,6 +41,7 @@ ui.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true +indoc.workspace = true which.workspace = true workspace-hack.workspace = true diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 660f61f907..212bb74d8a 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -1,11 +1,14 @@ mod claude; +mod codex; mod gemini; +mod mcp_server; mod settings; #[cfg(test)] mod e2e_tests; pub use claude::*; +pub use codex::*; pub use gemini::*; pub use settings::*; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 2e0eb271b6..4b48dbf3c1 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -44,7 +44,7 @@ impl AgentServer for ClaudeCode { } fn empty_state_message(&self) -> &'static str { - "" + "How can I help you today?" } fn logo(&self) -> ui::IconName { @@ -190,7 +190,7 @@ impl AgentConnection for ClaudeAgentConnection { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { let sessions = self.sessions.borrow(); let Some(session) = sessions.get(¶ms.session_id) else { return Task::ready(Err(anyhow!( @@ -350,7 +350,7 @@ impl ClaudeAgentSession { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { thread .update(cx, |thread, cx| { - thread.push_assistant_chunk(text.into(), false, cx) + thread.push_assistant_content_block(text.into(), false, cx) }) .log_err(); } @@ -387,9 +387,15 @@ impl ClaudeAgentSession { thread .update(cx, |thread, cx| { thread.update_tool_call( - acp::ToolCallId(tool_use_id.into()), - acp::ToolCallStatus::Completed, - (!content.is_empty()).then(|| vec![content.into()]), + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + content: (!content.is_empty()) + .then(|| vec![content.into()]), + ..Default::default() + }, + }, cx, ) }) @@ -402,7 +408,7 @@ impl ClaudeAgentSession { | ContentChunk::WebSearchToolResult => { thread .update(cx, |thread, cx| { - thread.push_assistant_chunk( + thread.push_assistant_content_block( format!("Unsupported content: {:?}", chunk).into(), false, cx, diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index ed25f9af7f..6acb6355aa 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -311,6 +311,7 @@ impl ClaudeTool { label: self.label(), content: self.content(), locations: self.locations(), + raw_input: None, } } } diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs new file mode 100644 index 0000000000..3eb95a6841 --- /dev/null +++ b/crates/agent_servers/src/codex.rs @@ -0,0 +1,317 @@ +use agent_client_protocol as acp; +use anyhow::anyhow; +use collections::HashMap; +use context_server::listener::McpServerTool; +use context_server::types::requests; +use context_server::{ContextServer, ContextServerCommand, ContextServerId}; +use futures::channel::{mpsc, oneshot}; +use project::Project; +use settings::SettingsStore; +use smol::stream::StreamExt as _; +use std::cell::RefCell; +use std::rc::Rc; +use std::{path::Path, sync::Arc}; +use util::ResultExt; + +use anyhow::{Context, Result}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; + +use crate::mcp_server::ZedMcpServer; +use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server}; +use acp_thread::{AcpThread, AgentConnection}; + +#[derive(Clone)] +pub struct Codex; + +impl AgentServer for Codex { + fn name(&self) -> &'static str { + "Codex" + } + + fn empty_state_headline(&self) -> &'static str { + "Welcome to Codex" + } + + fn empty_state_message(&self) -> &'static str { + "What can I help with?" + } + + fn logo(&self) -> ui::IconName { + ui::IconName::AiOpenAi + } + + fn connect( + &self, + _root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let project = project.clone(); + cx.spawn(async move |cx| { + let settings = cx.read_global(|settings: &SettingsStore, _| { + settings.get::(None).codex.clone() + })?; + + let Some(command) = + AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await + else { + anyhow::bail!("Failed to find codex binary"); + }; + + let client: Arc = ContextServer::stdio( + ContextServerId("codex-mcp-server".into()), + ContextServerCommand { + path: command.path, + args: command.args, + env: command.env, + }, + ) + .into(); + ContextServer::start(client.clone(), cx).await?; + + let (notification_tx, mut notification_rx) = mpsc::unbounded(); + client + .client() + .context("Failed to subscribe")? + .on_notification(acp::SESSION_UPDATE_METHOD_NAME, { + move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification) + .log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + } + }); + + let sessions = Rc::new(RefCell::new(HashMap::default())); + + let notification_handler_task = cx.spawn({ + let sessions = sessions.clone(); + async move |cx| { + while let Some(notification) = notification_rx.next().await { + CodexConnection::handle_session_notification( + notification, + sessions.clone(), + cx, + ) + } + } + }); + + let connection = CodexConnection { + client, + sessions, + _notification_handler_task: notification_handler_task, + }; + Ok(Rc::new(connection) as _) + }) + } +} + +struct CodexConnection { + client: Arc, + sessions: Rc>>, + _notification_handler_task: Task<()>, +} + +struct CodexSession { + thread: WeakEntity, + cancel_tx: Option>, + _mcp_server: ZedMcpServer, +} + +impl AgentConnection for CodexConnection { + fn name(&self) -> &'static str { + "Codex" + } + + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut AsyncApp, + ) -> Task>> { + let client = self.client.client(); + let sessions = self.sessions.clone(); + let cwd = cwd.to_path_buf(); + cx.spawn(async move |cx| { + let client = client.context("MCP server is not initialized yet")?; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + + let mcp_server = ZedMcpServer::new(thread_rx, cx).await?; + + let response = client + .request::(context_server::types::CallToolParams { + name: acp::NEW_SESSION_TOOL_NAME.into(), + arguments: Some(serde_json::to_value(acp::NewSessionArguments { + mcp_servers: [( + mcp_server::SERVER_NAME.to_string(), + mcp_server.server_config()?, + )] + .into(), + client_tools: acp::ClientTools { + request_permission: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::RequestPermissionTool::NAME.into(), + }), + read_text_file: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::ReadTextFileTool::NAME.into(), + }), + write_text_file: Some(acp::McpToolId { + mcp_server: mcp_server::SERVER_NAME.into(), + tool_name: mcp_server::WriteTextFileTool::NAME.into(), + }), + }, + cwd, + })?), + meta: None, + }) + .await?; + + if response.is_error.unwrap_or_default() { + return Err(anyhow!(response.text_contents())); + } + + let result = serde_json::from_value::( + response.structured_content.context("Empty response")?, + )?; + + let thread = + cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?; + + thread_tx.send(thread.downgrade())?; + + let session = CodexSession { + thread: thread.downgrade(), + cancel_tx: None, + _mcp_server: mcp_server, + }; + sessions.borrow_mut().insert(result.session_id, session); + + Ok(thread) + }) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } + + fn prompt( + &self, + params: agent_client_protocol::PromptArguments, + cx: &mut App, + ) -> Task> { + let client = self.client.client(); + let sessions = self.sessions.clone(); + + cx.foreground_executor().spawn(async move { + let client = client.context("MCP server is not initialized yet")?; + + let (new_cancel_tx, cancel_rx) = oneshot::channel(); + { + let mut sessions = sessions.borrow_mut(); + let session = sessions + .get_mut(¶ms.session_id) + .context("Session not found")?; + session.cancel_tx.replace(new_cancel_tx); + } + + let result = client + .request_with::( + context_server::types::CallToolParams { + name: acp::PROMPT_TOOL_NAME.into(), + arguments: Some(serde_json::to_value(params)?), + meta: None, + }, + Some(cancel_rx), + None, + ) + .await; + + if let Err(err) = &result + && err.is::() + { + return Ok(()); + } + + let response = result?; + + if response.is_error.unwrap_or_default() { + return Err(anyhow!(response.text_contents())); + } + + Ok(()) + }) + } + + fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) { + let mut sessions = self.sessions.borrow_mut(); + + if let Some(cancel_tx) = sessions + .get_mut(session_id) + .and_then(|session| session.cancel_tx.take()) + { + cancel_tx.send(()).ok(); + } + } +} + +impl CodexConnection { + pub fn handle_session_notification( + notification: acp::SessionNotification, + threads: Rc>>, + cx: &mut AsyncApp, + ) { + let threads = threads.borrow(); + let Some(thread) = threads + .get(¬ification.session_id) + .and_then(|session| session.thread.upgrade()) + else { + log::error!( + "Thread not found for session ID: {}", + notification.session_id + ); + return; + }; + + thread + .update(cx, |thread, cx| { + thread.handle_session_update(notification.update, cx) + }) + .log_err(); + } +} + +impl Drop for CodexConnection { + fn drop(&mut self) { + self.client.stop().log_err(); + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::AgentServerCommand; + use std::path::Path; + + crate::common_e2e_tests!(Codex); + + pub fn local_command() -> AgentServerCommand { + let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../../codex/codex-rs/target/debug/codex"); + + AgentServerCommand { + path: cli_path, + args: vec!["mcp".into()], + env: None, + } + } +} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 9bc6fd60fe..905f06a148 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -351,6 +351,9 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { gemini: Some(AgentServerSettings { command: crate::gemini::tests::local_command(), }), + codex: Some(AgentServerSettings { + command: crate::codex::tests::local_command(), + }), }, cx, ); diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs new file mode 100644 index 0000000000..47575fa3ea --- /dev/null +++ b/crates/agent_servers/src/mcp_server.rs @@ -0,0 +1,201 @@ +use acp_thread::AcpThread; +use agent_client_protocol as acp; +use anyhow::{Context, Result}; +use context_server::listener::{McpServerTool, ToolResponse}; +use context_server::types::{ + Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, + ToolsCapabilities, requests, +}; +use futures::channel::oneshot; +use gpui::{App, AsyncApp, Task, WeakEntity}; +use indoc::indoc; + +pub struct ZedMcpServer { + server: context_server::listener::McpServer, +} + +pub const SERVER_NAME: &str = "zed"; + +impl ZedMcpServer { + pub async fn new( + thread_rx: watch::Receiver>, + cx: &AsyncApp, + ) -> Result { + let mut mcp_server = context_server::listener::McpServer::new(cx).await?; + mcp_server.handle_request::(Self::handle_initialize); + + mcp_server.add_tool(RequestPermissionTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(ReadTextFileTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(WriteTextFileTool { + thread_rx: thread_rx.clone(), + }); + + Ok(Self { server: mcp_server }) + } + + pub fn server_config(&self) -> Result { + let zed_path = std::env::current_exe() + .context("finding current executable path for use in mcp_server")?; + + Ok(acp::McpServerConfig { + command: zed_path, + args: vec![ + "--nc".into(), + self.server.socket_path().display().to_string(), + ], + env: None, + }) + } + + fn handle_initialize(_: InitializeParams, cx: &App) -> Task> { + cx.foreground_executor().spawn(async move { + Ok(InitializeResponse { + protocol_version: ProtocolVersion("2025-06-18".into()), + capabilities: ServerCapabilities { + experimental: None, + logging: None, + completions: None, + prompts: None, + resources: None, + tools: Some(ToolsCapabilities { + list_changed: Some(false), + }), + }, + server_info: Implementation { + name: SERVER_NAME.into(), + version: "0.1.0".into(), + }, + meta: None, + }) + }) + } +} + +// Tools + +#[derive(Clone)] +pub struct RequestPermissionTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for RequestPermissionTool { + type Input = acp::RequestPermissionArguments; + type Output = acp::RequestPermissionOutput; + + const NAME: &'static str = "Confirmation"; + + fn description(&self) -> &'static str { + indoc! {" + Request permission for tool calls. + + This tool is meant to be called programmatically by the agent loop, not the LLM. + "} + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let result = thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission(input.tool_call, input.options, cx) + })? + .await; + + let outcome = match result { + Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id }, + Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled, + }; + + Ok(ToolResponse { + content: vec![], + structured_content: acp::RequestPermissionOutput { outcome }, + }) + } +} + +#[derive(Clone)] +pub struct ReadTextFileTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for ReadTextFileTool { + type Input = acp::ReadTextFileArguments; + type Output = acp::ReadTextFileOutput; + + const NAME: &'static str = "Read"; + + fn description(&self) -> &'static str { + "Reads the content of the given file in the project including unsaved changes." + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.path, input.line, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: acp::ReadTextFileOutput { content }, + }) + } +} + +#[derive(Clone)] +pub struct WriteTextFileTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for WriteTextFileTool { + type Input = acp::WriteTextFileArguments; + type Output = (); + + const NAME: &'static str = "Write"; + + fn description(&self) -> &'static str { + "Write to a file replacing its contents" + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.path, input.content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: (), + }) + } +} diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs index 645674b5f1..aeb34a5e61 100644 --- a/crates/agent_servers/src/settings.rs +++ b/crates/agent_servers/src/settings.rs @@ -13,6 +13,7 @@ pub fn init(cx: &mut App) { pub struct AllAgentServersSettings { pub gemini: Option, pub claude: Option, + pub codex: Option, } #[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)] @@ -29,13 +30,21 @@ impl settings::Settings for AllAgentServersSettings { fn load(sources: SettingsSources, _: &mut App) -> Result { let mut settings = AllAgentServersSettings::default(); - for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() { + for AllAgentServersSettings { + gemini, + claude, + codex, + } in sources.defaults_and_customizations() + { if gemini.is_some() { settings.gemini = gemini.clone(); } if claude.is_some() { settings.claude = claude.clone(); } + if codex.is_some() { + settings.codex = codex.clone(); + } } Ok(settings) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 7f5de9db5f..e46e1ae3ab 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -872,7 +872,10 @@ impl AcpThreadView { let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix)); let status_icon = match &tool_call.status { - ToolCallStatus::WaitingForConfirmation { .. } => None, + ToolCallStatus::Allowed { + status: acp::ToolCallStatus::Pending, + } + | ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::Allowed { status: acp::ToolCallStatus::InProgress, .. @@ -957,6 +960,8 @@ impl AcpThreadView { Icon::new(match tool_call.kind { acp::ToolKind::Read => IconName::ToolRead, acp::ToolKind::Edit => IconName::ToolPencil, + acp::ToolKind::Delete => IconName::ToolDeleteFile, + acp::ToolKind::Move => IconName::ArrowRightLeft, acp::ToolKind::Search => IconName::ToolSearch, acp::ToolKind::Execute => IconName::ToolTerminal, acp::ToolKind::Think => IconName::ToolBulb, @@ -1068,6 +1073,7 @@ impl AcpThreadView { options, entry_ix, tool_call.id.clone(), + tool_call.content.is_empty(), cx, )), ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => { @@ -1126,6 +1132,7 @@ impl AcpThreadView { options: &[acp::PermissionOption], entry_ix: usize, tool_call_id: acp::ToolCallId, + empty_content: bool, cx: &Context, ) -> Div { h_flex() @@ -1133,8 +1140,10 @@ impl AcpThreadView { .px_1p5() .gap_1() .justify_end() - .border_t_1() - .border_color(self.tool_card_border_color(cx)) + .when(!empty_content, |this| { + this.border_t_1() + .border_color(self.tool_card_border_color(cx)) + }) .children(options.iter().map(|option| { let option_id = SharedString::from(option.id.0.clone()); Button::new((option_id, entry_ix), option.label.clone()) diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 43c1167af8..61a65de50b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1991,6 +1991,20 @@ impl AgentPanel { ); }), ) + .item( + ContextMenuEntry::new("New Codex Thread") + .icon(IconName::AiOpenAi) + .icon_color(Color::Muted) + .handler(move |window, cx| { + window.dispatch_action( + NewExternalAgentThread { + agent: Some(crate::ExternalAgent::Codex), + } + .boxed_clone(), + cx, + ); + }), + ) }); menu })) @@ -2652,6 +2666,25 @@ impl AgentPanel { ) }, ), + ) + .child( + NewThreadButton::new( + "new-codex-thread-btn", + "New Codex Thread", + IconName::AiOpenAi, + ) + .on_click( + |window, cx| { + window.dispatch_action( + Box::new(NewExternalAgentThread { + agent: Some( + crate::ExternalAgent::Codex, + ), + }), + cx, + ) + }, + ), ), ) }), diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 6ae78585de..4b75cc9e77 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -150,6 +150,7 @@ enum ExternalAgent { #[default] Gemini, ClaudeCode, + Codex, } impl ExternalAgent { @@ -157,6 +158,7 @@ impl ExternalAgent { match self { ExternalAgent::Gemini => Rc::new(agent_servers::Gemini), ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode), + ExternalAgent::Codex => Rc::new(agent_servers::Codex), } } } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 8c5e7da0f1..ff4d79c07d 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -330,23 +330,16 @@ impl Client { method: &str, params: impl Serialize, ) -> Result { - self.request_impl(method, params, None).await + self.request_with(method, params, None, Some(REQUEST_TIMEOUT)) + .await } - pub async fn cancellable_request( - &self, - method: &str, - params: impl Serialize, - cancel_rx: oneshot::Receiver<()>, - ) -> Result { - self.request_impl(method, params, Some(cancel_rx)).await - } - - pub async fn request_impl( + pub async fn request_with( &self, method: &str, params: impl Serialize, cancel_rx: Option>, + timeout: Option, ) -> Result { let id = self.next_id.fetch_add(1, SeqCst); let request = serde_json::to_string(&Request { @@ -382,7 +375,13 @@ impl Client { handle_response?; send?; - let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); + let mut timeout_fut = pin!( + match timeout { + Some(timeout) => future::Either::Left(executor.timer(timeout)), + None => future::Either::Right(future::pending()), + } + .fuse() + ); let mut cancel_fut = pin!( match cancel_rx { Some(rx) => future::Either::Left(async { @@ -419,10 +418,10 @@ impl Client { reason: None }) ).log_err(); - anyhow::bail!("Request cancelled") + anyhow::bail!(RequestCanceled) } - _ = timeout => { - log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); + _ = timeout_fut => { + log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap()); anyhow::bail!("Context server request timeout"); } } @@ -452,6 +451,17 @@ impl Client { } } +#[derive(Debug)] +pub struct RequestCanceled; + +impl std::error::Error for RequestCanceled {} + +impl std::fmt::Display for RequestCanceled { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Context server request was canceled") + } +} + impl fmt::Display for ContextServerId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 192f530816..34e3a9a78c 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -419,7 +419,7 @@ pub struct ToolResponse { pub structured_content: T, } -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] struct RawRequest { #[serde(skip_serializing_if = "Option::is_none")] id: Option, diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 7263f502fa..9ccbc8a553 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -5,6 +5,8 @@ //! read/write messages and the types from types.rs for serialization/deserialization //! of messages. +use std::time::Duration; + use anyhow::Result; use futures::channel::oneshot; use gpui::AsyncApp; @@ -98,13 +100,14 @@ impl InitializedContextServerProtocol { self.inner.request(T::METHOD, params).await } - pub async fn cancellable_request( + pub async fn request_with( &self, params: T::Params, - cancel_rx: oneshot::Receiver<()>, + cancel_rx: Option>, + timeout: Option, ) -> Result { self.inner - .cancellable_request(T::METHOD, params, cancel_rx) + .request_with(T::METHOD, params, cancel_rx, timeout) .await } diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index cd97ff95bc..5fa2420a3d 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -626,6 +626,7 @@ pub enum ClientNotification { } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct CancelledParams { pub request_id: RequestId, #[serde(skip_serializing_if = "Option::is_none")] @@ -685,6 +686,18 @@ pub struct CallToolResponse { pub structured_content: Option, } +impl CallToolResponse { + pub fn text_contents(&self) -> String { + let mut text = String::new(); + for chunk in &self.content { + if let ToolResponseContent::Text { text: chunk } = chunk { + text.push_str(&chunk) + }; + } + text + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ToolResponseContent {