diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index 4e2244ba9b..3c4c2f7473 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -69,7 +69,7 @@ pub struct PatchApprovalRequest { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "codex_elicitation", rename_all = "snake_case")] -enum CodexElicitation { +pub enum CodexElicitation { ExecApproval(ExecApprovalRequest), PatchApproval(PatchApprovalRequest), } @@ -195,11 +195,12 @@ impl AgentServer for Codex { let client = codex_mcp_client .client() .context("Failed to subscribe to server")?; + client.on_notification("codex/event", { move |event, cx| { let mut notification_tx = notification_tx.clone(); cx.background_spawn(async move { - log::trace!("Notification: {:?}", event); + log::trace!("Notification: {:?}", serde_json::to_string_pretty(&event)); if let Some(event) = serde_json::from_value::(event).log_err() { notification_tx.send(event.msg).await.log_err(); } @@ -209,16 +210,19 @@ impl AgentServer for Codex { }); client.on_request::({ - let delegate = delegate.clone(); - { - move |elicitation, cx| { - let (tx, rx) = oneshot::channel::>(); - request_tx.send((elicitation, tx)); - cx.foreground_executor().spawn(rx) - } + move |elicitation, cx| { + let (tx, rx) = oneshot::channel::>(); + let mut request_tx = request_tx.clone(); + cx.background_spawn(async move { + log::trace!("Elicitation: {:?}", elicitation); + request_tx.send((elicitation, tx)).await?; + rx.await? + }) } }); + let requested_call_id = Rc::new(RefCell::new(None)); + cx.new(|cx| { let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); delegate_tx.send(Some(delegate.clone())).log_err(); @@ -226,12 +230,14 @@ impl AgentServer for Codex { let handler_task = cx.spawn({ let delegate = delegate.clone(); let tool_id_map = tool_id_map.clone(); + let requested_call_id = requested_call_id.clone(); async move |_, _cx| { while let Some(notification) = notification_rx.next().await { CodexAgentConnection::handle_acp_notification( &delegate, notification, &tool_id_map, + &requested_call_id, ) .await .log_err(); @@ -241,83 +247,20 @@ impl AgentServer for Codex { let request_task = cx.spawn({ let delegate = delegate.clone(); - let tool_id_map = tool_id_map.clone(); async move |_, _cx| { - while let Some((elicitation, respond)) = request_tx.next().await { - let confirmation = match elicitation { - CodexElicitation::ExecApproval(exec) => { - let inner_command = - strip_bash_lc_and_escape(&exec.codex_command); + while let Some((elicitation, respond)) = request_rx.next().await { + if let Some((id, decision)) = + CodexAgentConnection::handle_elicitation(&delegate, elicitation) + .await + .log_err() + { + requested_call_id.replace(Some(id)); - acp::RequestToolCallConfirmationParams { - tool_call: acp::PushToolCallParams { - label: todo!(), - icon: acp::Icon::Terminal, - content: None, - locations: vec![], - }, - confirmation: acp::ToolCallConfirmation::Execute { - root_command: inner_command - .split(" ") - .next() - .unwrap_or_default() - .to_string(), - command: inner_command, - description: Some(exec.message), - }, - } - } - CodexElicitation::PatchApproval(patch) => { - acp::RequestToolCallConfirmationParams { - tool_call: acp::PushToolCallParams { - label: "Edit".to_string(), - icon: acp::Icon::Pencil, - content: None, // todo!() - locations: patch - .codex_changes - .keys() - .map(|path| acp::ToolCallLocation { - path: path.clone(), - line: None, - }) - .collect(), - }, - confirmation: acp::ToolCallConfirmation::Edit { - description: Some(patch.message), - }, - } - } - }; - - let task = cx.spawn(async move |cx| { - let response = delegate - .request_tool_call_confirmation(confirmation) - .await?; - - let decision = match response.outcome { - acp::ToolCallConfirmationOutcome::Allow => { - ReviewDecision::Approved - } - acp::ToolCallConfirmationOutcome::AlwaysAllow - | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer - | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => { - ReviewDecision::ApprovedForSession - } - acp::ToolCallConfirmationOutcome::Reject => { - ReviewDecision::Denied - } - acp::ToolCallConfirmationOutcome::Cancel => { - ReviewDecision::Abort - } - }; - - Ok(CodexApprovalResponse { decision }) - }); + respond + .send(Ok(CodexApprovalResponse { decision })) + .log_err(); + } } - - cx.spawn(async move |cx| { - tx.send(task.await).ok(); - }) } }); @@ -325,7 +268,6 @@ impl AgentServer for Codex { root_dir, codex_mcp: codex_mcp_client, cancel_request_tx: Default::default(), - tool_id_map: tool_id_map.clone(), _handler_task: handler_task, _request_task: request_task, _zed_mcp: zed_mcp_server, @@ -413,17 +355,83 @@ struct CodexAgentConnection { codex_mcp: Arc, root_dir: PathBuf, cancel_request_tx: Rc>>>, - tool_id_map: Rc>>, _handler_task: Task<()>, _request_task: Task<()>, _zed_mcp: ZedMcpServer, } impl CodexAgentConnection { + async fn handle_elicitation( + delegate: &AcpClientDelegate, + elicitation: CodexElicitation, + ) -> Result<(acp::ToolCallId, ReviewDecision)> { + let confirmation = match elicitation { + CodexElicitation::ExecApproval(exec) => { + let inner_command = strip_bash_lc_and_escape(&exec.codex_command); + + acp::RequestToolCallConfirmationParams { + tool_call: acp::PushToolCallParams { + label: format!("`{inner_command}`"), + icon: acp::Icon::Terminal, + content: None, + locations: vec![], + }, + confirmation: acp::ToolCallConfirmation::Execute { + root_command: inner_command + .split(" ") + .next() + .unwrap_or_default() + .to_string(), + command: inner_command, + description: Some(exec.message), + }, + } + } + CodexElicitation::PatchApproval(patch) => { + acp::RequestToolCallConfirmationParams { + tool_call: acp::PushToolCallParams { + label: "Edit".to_string(), + icon: acp::Icon::Pencil, + content: None, // todo!() + locations: patch + .codex_changes + .keys() + .map(|path| acp::ToolCallLocation { + path: path.clone(), + line: None, + }) + .collect(), + }, + confirmation: acp::ToolCallConfirmation::Edit { + description: Some(patch.message), + }, + } + } + }; + + let response = delegate + .request_tool_call_confirmation(confirmation) + .await?; + + let decision = match response.outcome { + acp::ToolCallConfirmationOutcome::Allow => ReviewDecision::Approved, + acp::ToolCallConfirmationOutcome::AlwaysAllow + | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer + | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => { + ReviewDecision::ApprovedForSession + } + acp::ToolCallConfirmationOutcome::Reject => ReviewDecision::Denied, + acp::ToolCallConfirmationOutcome::Cancel => ReviewDecision::Abort, + }; + + Ok((response.id, decision)) + } + async fn handle_acp_notification( delegate: &AcpClientDelegate, event: AcpNotification, tool_id_map: &Rc>>, + requested_call_id: &Rc>>, ) -> Result<()> { match event { AcpNotification::AgentMessage(message) => { @@ -445,23 +453,29 @@ impl CodexAgentConnection { .await? } AcpNotification::McpToolCallBegin(event) => { - let result = delegate - .push_tool_call(acp::PushToolCallParams { - label: format!("`{}: {}`", event.server, event.tool), - icon: acp::Icon::Hammer, - content: event.arguments.and_then(|args| { - Some(acp::ToolCallContent::Markdown { - markdown: md_codeblock( - "json", - &serde_json::to_string_pretty(&args).ok()?, - ), - }) - }), - locations: vec![], - }) - .await?; + if let Some(requested_tool_id) = requested_call_id.take() { + tool_id_map + .borrow_mut() + .insert(event.call_id, requested_tool_id); + } else { + let result = delegate + .push_tool_call(acp::PushToolCallParams { + label: format!("`{}: {}`", event.server, event.tool), + icon: acp::Icon::Hammer, + content: event.arguments.and_then(|args| { + Some(acp::ToolCallContent::Markdown { + markdown: md_codeblock( + "json", + &serde_json::to_string_pretty(&args).ok()?, + ), + }) + }), + locations: vec![], + }) + .await?; - tool_id_map.borrow_mut().insert(event.call_id, result.id); + tool_id_map.borrow_mut().insert(event.call_id, result.id); + } } AcpNotification::McpToolCallEnd(event) => { let acp_call_id = tool_id_map @@ -502,18 +516,24 @@ impl CodexAgentConnection { .await?; } AcpNotification::ExecCommandBegin(event) => { - let inner_command = strip_bash_lc_and_escape(&event.command); + if let Some(requested_tool_id) = requested_call_id.take() { + tool_id_map + .borrow_mut() + .insert(event.call_id, requested_tool_id); + } else { + let inner_command = strip_bash_lc_and_escape(&event.command); - let result = delegate - .push_tool_call(acp::PushToolCallParams { - label: format!("`{}`", inner_command), - icon: acp::Icon::Terminal, - content: None, - locations: vec![], - }) - .await?; + let result = delegate + .push_tool_call(acp::PushToolCallParams { + label: format!("`{}`", inner_command), + icon: acp::Icon::Terminal, + content: None, + locations: vec![], + }) + .await?; - tool_id_map.borrow_mut().insert(event.call_id, result.id); + tool_id_map.borrow_mut().insert(event.call_id, result.id); + } } AcpNotification::ExecCommandEnd(event) => { let acp_call_id = tool_id_map @@ -562,34 +582,6 @@ impl CodexAgentConnection { }) .await?; } - AcpNotification::ExecApprovalRequest(event) => { - let inner_command = strip_bash_lc_and_escape(&event.command); - let root_command = inner_command - .split(" ") - .next() - .map(|s| s.to_string()) - .unwrap_or_default(); - - let response = delegate - .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams { - tool_call: acp::PushToolCallParams { - label: format!("`{}`", inner_command), - icon: acp::Icon::Terminal, - content: None, - locations: vec![], - }, - confirmation: acp::ToolCallConfirmation::Execute { - command: inner_command, - root_command, - description: event.reason, - }, - }) - .await?; - - tool_id_map.borrow_mut().insert(event.call_id, response.id); - - // todo! approval - } AcpNotification::Other => {} } @@ -620,7 +612,6 @@ pub enum AcpNotification { McpToolCallEnd(McpToolCallEndEvent), ExecCommandBegin(ExecCommandBeginEvent), ExecCommandEnd(ExecCommandEndEvent), - ExecApprovalRequest(ExecApprovalRequestEvent), #[serde(other)] Other, }