Wire up elicitations

This commit is contained in:
Agus Zubiaga 2025-07-22 17:34:02 -03:00
parent 966d29dcd9
commit cedd6aa704

View file

@ -69,7 +69,7 @@ pub struct PatchApprovalRequest {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "codex_elicitation", rename_all = "snake_case")] #[serde(tag = "codex_elicitation", rename_all = "snake_case")]
enum CodexElicitation { pub enum CodexElicitation {
ExecApproval(ExecApprovalRequest), ExecApproval(ExecApprovalRequest),
PatchApproval(PatchApprovalRequest), PatchApproval(PatchApprovalRequest),
} }
@ -195,11 +195,12 @@ impl AgentServer for Codex {
let client = codex_mcp_client let client = codex_mcp_client
.client() .client()
.context("Failed to subscribe to server")?; .context("Failed to subscribe to server")?;
client.on_notification("codex/event", { client.on_notification("codex/event", {
move |event, cx| { move |event, cx| {
let mut notification_tx = notification_tx.clone(); let mut notification_tx = notification_tx.clone();
cx.background_spawn(async move { cx.background_spawn(async move {
log::trace!("Notification: {:?}", event); log::trace!("Notification: {:?}", serde_json::to_string_pretty(&event));
if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() { if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
notification_tx.send(event.msg).await.log_err(); notification_tx.send(event.msg).await.log_err();
} }
@ -209,16 +210,19 @@ impl AgentServer for Codex {
}); });
client.on_request::<CodexApproval, _>({ client.on_request::<CodexApproval, _>({
let delegate = delegate.clone(); move |elicitation, cx| {
{ let (tx, rx) = oneshot::channel::<Result<CodexApprovalResponse>>();
move |elicitation, cx| { let mut request_tx = request_tx.clone();
let (tx, rx) = oneshot::channel::<Result<CodexApprovalResponse>>(); cx.background_spawn(async move {
request_tx.send((elicitation, tx)); log::trace!("Elicitation: {:?}", elicitation);
cx.foreground_executor().spawn(rx) request_tx.send((elicitation, tx)).await?;
} rx.await?
})
} }
}); });
let requested_call_id = Rc::new(RefCell::new(None));
cx.new(|cx| { cx.new(|cx| {
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
delegate_tx.send(Some(delegate.clone())).log_err(); delegate_tx.send(Some(delegate.clone())).log_err();
@ -226,12 +230,14 @@ impl AgentServer for Codex {
let handler_task = cx.spawn({ let handler_task = cx.spawn({
let delegate = delegate.clone(); let delegate = delegate.clone();
let tool_id_map = tool_id_map.clone(); let tool_id_map = tool_id_map.clone();
let requested_call_id = requested_call_id.clone();
async move |_, _cx| { async move |_, _cx| {
while let Some(notification) = notification_rx.next().await { while let Some(notification) = notification_rx.next().await {
CodexAgentConnection::handle_acp_notification( CodexAgentConnection::handle_acp_notification(
&delegate, &delegate,
notification, notification,
&tool_id_map, &tool_id_map,
&requested_call_id,
) )
.await .await
.log_err(); .log_err();
@ -241,83 +247,20 @@ impl AgentServer for Codex {
let request_task = cx.spawn({ let request_task = cx.spawn({
let delegate = delegate.clone(); let delegate = delegate.clone();
let tool_id_map = tool_id_map.clone();
async move |_, _cx| { async move |_, _cx| {
while let Some((elicitation, respond)) = request_tx.next().await { while let Some((elicitation, respond)) = request_rx.next().await {
let confirmation = match elicitation { if let Some((id, decision)) =
CodexElicitation::ExecApproval(exec) => { CodexAgentConnection::handle_elicitation(&delegate, elicitation)
let inner_command = .await
strip_bash_lc_and_escape(&exec.codex_command); .log_err()
{
requested_call_id.replace(Some(id));
acp::RequestToolCallConfirmationParams { respond
tool_call: acp::PushToolCallParams { .send(Ok(CodexApprovalResponse { decision }))
label: todo!(), .log_err();
icon: acp::Icon::Terminal, }
content: None,
locations: vec![],
},
confirmation: acp::ToolCallConfirmation::Execute {
root_command: inner_command
.split(" ")
.next()
.unwrap_or_default()
.to_string(),
command: inner_command,
description: Some(exec.message),
},
}
}
CodexElicitation::PatchApproval(patch) => {
acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: "Edit".to_string(),
icon: acp::Icon::Pencil,
content: None, // todo!()
locations: patch
.codex_changes
.keys()
.map(|path| acp::ToolCallLocation {
path: path.clone(),
line: None,
})
.collect(),
},
confirmation: acp::ToolCallConfirmation::Edit {
description: Some(patch.message),
},
}
}
};
let task = cx.spawn(async move |cx| {
let response = delegate
.request_tool_call_confirmation(confirmation)
.await?;
let decision = match response.outcome {
acp::ToolCallConfirmationOutcome::Allow => {
ReviewDecision::Approved
}
acp::ToolCallConfirmationOutcome::AlwaysAllow
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
ReviewDecision::ApprovedForSession
}
acp::ToolCallConfirmationOutcome::Reject => {
ReviewDecision::Denied
}
acp::ToolCallConfirmationOutcome::Cancel => {
ReviewDecision::Abort
}
};
Ok(CodexApprovalResponse { decision })
});
} }
cx.spawn(async move |cx| {
tx.send(task.await).ok();
})
} }
}); });
@ -325,7 +268,6 @@ impl AgentServer for Codex {
root_dir, root_dir,
codex_mcp: codex_mcp_client, codex_mcp: codex_mcp_client,
cancel_request_tx: Default::default(), cancel_request_tx: Default::default(),
tool_id_map: tool_id_map.clone(),
_handler_task: handler_task, _handler_task: handler_task,
_request_task: request_task, _request_task: request_task,
_zed_mcp: zed_mcp_server, _zed_mcp: zed_mcp_server,
@ -413,17 +355,83 @@ struct CodexAgentConnection {
codex_mcp: Arc<context_server::ContextServer>, codex_mcp: Arc<context_server::ContextServer>,
root_dir: PathBuf, root_dir: PathBuf,
cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>, cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
_handler_task: Task<()>, _handler_task: Task<()>,
_request_task: Task<()>, _request_task: Task<()>,
_zed_mcp: ZedMcpServer, _zed_mcp: ZedMcpServer,
} }
impl CodexAgentConnection { impl CodexAgentConnection {
async fn handle_elicitation(
delegate: &AcpClientDelegate,
elicitation: CodexElicitation,
) -> Result<(acp::ToolCallId, ReviewDecision)> {
let confirmation = match elicitation {
CodexElicitation::ExecApproval(exec) => {
let inner_command = strip_bash_lc_and_escape(&exec.codex_command);
acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: format!("`{inner_command}`"),
icon: acp::Icon::Terminal,
content: None,
locations: vec![],
},
confirmation: acp::ToolCallConfirmation::Execute {
root_command: inner_command
.split(" ")
.next()
.unwrap_or_default()
.to_string(),
command: inner_command,
description: Some(exec.message),
},
}
}
CodexElicitation::PatchApproval(patch) => {
acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: "Edit".to_string(),
icon: acp::Icon::Pencil,
content: None, // todo!()
locations: patch
.codex_changes
.keys()
.map(|path| acp::ToolCallLocation {
path: path.clone(),
line: None,
})
.collect(),
},
confirmation: acp::ToolCallConfirmation::Edit {
description: Some(patch.message),
},
}
}
};
let response = delegate
.request_tool_call_confirmation(confirmation)
.await?;
let decision = match response.outcome {
acp::ToolCallConfirmationOutcome::Allow => ReviewDecision::Approved,
acp::ToolCallConfirmationOutcome::AlwaysAllow
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
ReviewDecision::ApprovedForSession
}
acp::ToolCallConfirmationOutcome::Reject => ReviewDecision::Denied,
acp::ToolCallConfirmationOutcome::Cancel => ReviewDecision::Abort,
};
Ok((response.id, decision))
}
async fn handle_acp_notification( async fn handle_acp_notification(
delegate: &AcpClientDelegate, delegate: &AcpClientDelegate,
event: AcpNotification, event: AcpNotification,
tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>, tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
requested_call_id: &Rc<RefCell<Option<acp::ToolCallId>>>,
) -> Result<()> { ) -> Result<()> {
match event { match event {
AcpNotification::AgentMessage(message) => { AcpNotification::AgentMessage(message) => {
@ -445,23 +453,29 @@ impl CodexAgentConnection {
.await? .await?
} }
AcpNotification::McpToolCallBegin(event) => { AcpNotification::McpToolCallBegin(event) => {
let result = delegate if let Some(requested_tool_id) = requested_call_id.take() {
.push_tool_call(acp::PushToolCallParams { tool_id_map
label: format!("`{}: {}`", event.server, event.tool), .borrow_mut()
icon: acp::Icon::Hammer, .insert(event.call_id, requested_tool_id);
content: event.arguments.and_then(|args| { } else {
Some(acp::ToolCallContent::Markdown { let result = delegate
markdown: md_codeblock( .push_tool_call(acp::PushToolCallParams {
"json", label: format!("`{}: {}`", event.server, event.tool),
&serde_json::to_string_pretty(&args).ok()?, icon: acp::Icon::Hammer,
), content: event.arguments.and_then(|args| {
}) Some(acp::ToolCallContent::Markdown {
}), markdown: md_codeblock(
locations: vec![], "json",
}) &serde_json::to_string_pretty(&args).ok()?,
.await?; ),
})
}),
locations: vec![],
})
.await?;
tool_id_map.borrow_mut().insert(event.call_id, result.id); tool_id_map.borrow_mut().insert(event.call_id, result.id);
}
} }
AcpNotification::McpToolCallEnd(event) => { AcpNotification::McpToolCallEnd(event) => {
let acp_call_id = tool_id_map let acp_call_id = tool_id_map
@ -502,18 +516,24 @@ impl CodexAgentConnection {
.await?; .await?;
} }
AcpNotification::ExecCommandBegin(event) => { AcpNotification::ExecCommandBegin(event) => {
let inner_command = strip_bash_lc_and_escape(&event.command); if let Some(requested_tool_id) = requested_call_id.take() {
tool_id_map
.borrow_mut()
.insert(event.call_id, requested_tool_id);
} else {
let inner_command = strip_bash_lc_and_escape(&event.command);
let result = delegate let result = delegate
.push_tool_call(acp::PushToolCallParams { .push_tool_call(acp::PushToolCallParams {
label: format!("`{}`", inner_command), label: format!("`{}`", inner_command),
icon: acp::Icon::Terminal, icon: acp::Icon::Terminal,
content: None, content: None,
locations: vec![], locations: vec![],
}) })
.await?; .await?;
tool_id_map.borrow_mut().insert(event.call_id, result.id); tool_id_map.borrow_mut().insert(event.call_id, result.id);
}
} }
AcpNotification::ExecCommandEnd(event) => { AcpNotification::ExecCommandEnd(event) => {
let acp_call_id = tool_id_map let acp_call_id = tool_id_map
@ -562,34 +582,6 @@ impl CodexAgentConnection {
}) })
.await?; .await?;
} }
AcpNotification::ExecApprovalRequest(event) => {
let inner_command = strip_bash_lc_and_escape(&event.command);
let root_command = inner_command
.split(" ")
.next()
.map(|s| s.to_string())
.unwrap_or_default();
let response = delegate
.request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: format!("`{}`", inner_command),
icon: acp::Icon::Terminal,
content: None,
locations: vec![],
},
confirmation: acp::ToolCallConfirmation::Execute {
command: inner_command,
root_command,
description: event.reason,
},
})
.await?;
tool_id_map.borrow_mut().insert(event.call_id, response.id);
// todo! approval
}
AcpNotification::Other => {} AcpNotification::Other => {}
} }
@ -620,7 +612,6 @@ pub enum AcpNotification {
McpToolCallEnd(McpToolCallEndEvent), McpToolCallEnd(McpToolCallEndEvent),
ExecCommandBegin(ExecCommandBeginEvent), ExecCommandBegin(ExecCommandBeginEvent),
ExecCommandEnd(ExecCommandEndEvent), ExecCommandEnd(ExecCommandEndEvent),
ExecApprovalRequest(ExecApprovalRequestEvent),
#[serde(other)] #[serde(other)]
Other, Other,
} }