Compare commits

...
Sign in to create a new pull request.

20 commits

Author SHA1 Message Date
Agus Zubiaga
2a00a53fcf Clean up 2025-07-22 19:26:18 -03:00
Agus Zubiaga
1e5625c4b4 Merge branch 'main' into mcp-codex 2025-07-22 19:24:10 -03:00
Agus Zubiaga
20c0a06485 Use codex-reply on the second request 2025-07-22 19:21:48 -03:00
Agus Zubiaga
598a8180b5 Initialize codex with MCP server 2025-07-22 19:03:28 -03:00
Agus Zubiaga
07a08f0972 Fix elicitation decoding 2025-07-22 17:43:28 -03:00
Agus Zubiaga
cedd6aa704 Wire up elicitations 2025-07-22 17:34:02 -03:00
Conrad Irwin
966d29dcd9 WIp 2025-07-22 13:49:28 -06:00
Conrad Irwin
cede9d757a Eliciting better 2025-07-22 13:10:06 -06:00
Conrad Irwin
c0c698b883 WIP elciication
t
2025-07-22 12:29:27 -06:00
Ben Brandt
0fa7d58a3e
Pass Zed MCP Server as command args 2025-07-22 15:11:48 +02:00
Ben Brandt
9b91445967
Use pathbuf in McpServerConfig 2025-07-22 15:06:39 +02:00
Ben Brandt
480adade63
Add correct codex path and setup e2e tests (not running yet) 2025-07-22 14:13:05 +02:00
Ben Brandt
47dec0df99
Add local command for codex 2025-07-22 13:56:55 +02:00
Ben Brandt
f20edf1b50
Remove unneeded todo 2025-07-22 12:20:52 +02:00
Ben Brandt
03b94f5831
Merge branch 'main' into mcp-codex 2025-07-22 12:20:13 +02:00
Agus Zubiaga
e7298c0736 Display tool calls 2025-07-21 20:23:27 -03:00
Agus Zubiaga
a822711e99 Stop generation 2025-07-21 19:46:14 -03:00
Agus Zubiaga
4b1ace9a54 Handle Agent Reasoning 2025-07-21 18:29:13 -03:00
Agus Zubiaga
769d6dc632 Stream text 2025-07-21 18:26:28 -03:00
Agus Zubiaga
f56910556f Connect to Codex MCP server 2025-07-21 15:19:50 -03:00
15 changed files with 1114 additions and 74 deletions

2
Cargo.lock generated
View file

@ -159,6 +159,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
"shlex",
"smol",
"strum 0.27.1",
"tempfile",
@ -20204,6 +20205,7 @@ dependencies = [
"diagnostics",
"editor",
"env_logger 0.11.8",
"erased-serde",
"extension",
"extension_host",
"extensions_ui",

View file

@ -40,6 +40,7 @@ util.workspace = true
uuid.workspace = true
watch.workspace = true
which.workspace = true
shlex.workspace = true
workspace-hack.workspace = true
[target.'cfg(unix)'.dependencies]

View file

@ -1,5 +1,7 @@
mod claude;
mod codex;
mod gemini;
mod mcp_server;
mod settings;
mod stdio_agent_server;
@ -7,6 +9,7 @@ mod stdio_agent_server;
mod e2e_tests;
pub use claude::*;
pub use codex::*;
pub use gemini::*;
pub use settings::*;
pub use stdio_agent_server::*;

View file

@ -1,5 +1,4 @@
mod mcp_server;
mod tools;
pub mod tools;
use collections::HashMap;
use project::Project;
@ -30,8 +29,8 @@ use gpui::{App, AppContext, Entity, Task};
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::claude::mcp_server::ClaudeMcpServer;
use crate::claude::tools::ClaudeTool;
use crate::mcp_server::{self, McpConfig, ZedMcpServer};
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
@ -72,12 +71,13 @@ impl AgentServer for ClaudeCode {
let (mut delegate_tx, delegate_rx) = watch::channel(None);
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
let permission_mcp_server =
ZedMcpServer::new(delegate_rx, tool_id_map.clone(), Default::default(), cx).await?;
let mut mcp_servers = HashMap::default();
mcp_servers.insert(
mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?,
crate::mcp_server::SERVER_NAME.to_string(),
permission_mcp_server.server_config()?,
);
let mcp_config = McpConfig { mcp_servers };
@ -187,7 +187,7 @@ impl AgentServer for ClaudeCode {
_mcp_server: None,
};
connection._mcp_server = Some(mcp_server);
connection._mcp_server = Some(permission_mcp_server);
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
})
})
@ -333,7 +333,7 @@ struct ClaudeAgentConnection {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
_mcp_server: Option<ClaudeMcpServer>,
_mcp_server: Option<ZedMcpServer>,
_handler_task: Task<()>,
}
@ -661,21 +661,6 @@ enum PermissionMode {
Plan,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpConfig {
mcp_servers: HashMap<String, McpServerConfig>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpServerConfig {
command: String,
args: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
env: Option<HashMap<String, String>>,
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;

View file

@ -0,0 +1,808 @@
use collections::HashMap;
use context_server::types::requests::CallTool;
use context_server::types::{CallToolParams, ToolResponseContent};
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::cell::RefCell;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
use uuid::Uuid;
use agentic_coding_protocol::{
self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
};
use anyhow::{Context, Result, anyhow};
use futures::future::LocalBoxFuture;
use futures::{FutureExt, SinkExt as _};
use gpui::{App, AppContext, Entity, Task};
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::mcp_server::{self, McpServerConfig, ZedMcpServer};
use crate::tools::{EditToolParams, ReadToolParams};
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
#[derive(Clone)]
pub struct Codex;
impl AgentServer for Codex {
fn name(&self) -> &'static str {
"Codex"
}
fn empty_state_headline(&self) -> &'static str {
self.name()
}
fn empty_state_message(&self) -> &'static str {
""
}
fn logo(&self) -> ui::IconName {
ui::IconName::AiOpenAi
}
fn supports_always_allow(&self) -> bool {
false
}
fn new_thread(
&self,
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let project = project.clone();
let root_dir = root_dir.to_path_buf();
let title = self.name().into();
cx.spawn(async move |cx| {
let (mut delegate_tx, delegate_rx) = watch::channel(None);
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
let zed_mcp_server = ZedMcpServer::new(
delegate_rx,
tool_id_map.clone(),
mcp_server::EnabledTools {
permission: false,
..Default::default()
},
cx,
)
.await?;
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).codex.clone()
})?;
let Some(command) =
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
else {
anyhow::bail!("Failed to find codex binary");
};
let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
ContextServerId("codex-mcp-server".into()),
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
)
.into();
ContextServer::start(codex_mcp_client.clone(), cx).await?;
// todo! stop
let (notification_tx, mut notification_rx) = mpsc::unbounded();
let (request_tx, mut request_rx) = mpsc::unbounded();
let client = codex_mcp_client
.client()
.context("Failed to subscribe to server")?;
client.on_notification("codex/event", {
move |event, cx| {
let mut notification_tx = notification_tx.clone();
cx.background_spawn(async move {
log::trace!("Notification: {:?}", serde_json::to_string_pretty(&event));
if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
notification_tx.send(event.msg).await.log_err();
}
})
.detach();
}
});
client.on_request::<CodexApproval, _>({
move |elicitation, cx| {
let (tx, rx) = oneshot::channel::<Result<CodexApprovalResponse>>();
let mut request_tx = request_tx.clone();
cx.background_spawn(async move {
log::trace!("Elicitation: {:?}", elicitation);
request_tx.send((elicitation, tx)).await?;
rx.await?
})
}
});
let requested_call_id = Rc::new(RefCell::new(None));
let session_id = Rc::new(RefCell::new(None));
cx.new(|cx| {
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
delegate_tx.send(Some(delegate.clone())).log_err();
let handler_task = cx.spawn({
let delegate = delegate.clone();
let tool_id_map = tool_id_map.clone();
let requested_call_id = requested_call_id.clone();
let session_id = session_id.clone();
async move |_, _cx| {
while let Some(notification) = notification_rx.next().await {
CodexAgentConnection::handle_acp_notification(
&delegate,
notification,
&session_id,
&tool_id_map,
&requested_call_id,
)
.await
.log_err();
}
}
});
let request_task = cx.spawn({
let delegate = delegate.clone();
async move |_, _cx| {
while let Some((elicitation, respond)) = request_rx.next().await {
if let Some((id, decision)) =
CodexAgentConnection::handle_elicitation(&delegate, elicitation)
.await
.log_err()
{
requested_call_id.replace(Some(id));
respond
.send(Ok(CodexApprovalResponse { decision }))
.log_err();
}
}
}
});
let connection = CodexAgentConnection {
root_dir,
codex_mcp: codex_mcp_client,
cancel_request_tx: Default::default(),
session_id,
zed_mcp_server,
_handler_task: handler_task,
_request_task: request_task,
};
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
})
})
}
}
struct CodexAgentConnection {
codex_mcp: Arc<context_server::ContextServer>,
root_dir: PathBuf,
cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
session_id: Rc<RefCell<Option<Uuid>>>,
zed_mcp_server: ZedMcpServer,
_handler_task: Task<()>,
_request_task: Task<()>,
}
impl AgentConnection for CodexAgentConnection {
/// Send a request to the agent and wait for a response.
fn request_any(
&self,
params: AnyAgentRequest,
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
let client = self.codex_mcp.client();
let root_dir = self.root_dir.clone();
let cancel_request_tx = self.cancel_request_tx.clone();
let mcp_config = self.zed_mcp_server.server_config();
let session_id = self.session_id.clone();
async move {
let client = client.context("Codex MCP server is not initialized")?;
match params {
// todo: consider sending an empty request so we get the init response?
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
acp::InitializeResponse {
is_authenticated: true,
protocol_version: ProtocolVersion::latest(),
},
)),
AnyAgentRequest::AuthenticateParams(_) => {
Err(anyhow!("Authentication not supported"))
}
AnyAgentRequest::SendUserMessageParams(message) => {
let (new_cancel_tx, cancel_rx) = oneshot::channel();
cancel_request_tx.borrow_mut().replace(new_cancel_tx);
let prompt = message
.chunks
.into_iter()
.filter_map(|chunk| match chunk {
acp::UserMessageChunk::Text { text } => Some(text),
acp::UserMessageChunk::Path { .. } => {
// todo!
None
}
})
.collect();
let params = if let Some(session_id) = *session_id.borrow() {
CallToolParams {
name: "codex-reply".into(),
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
prompt,
session_id,
})?),
meta: None,
}
} else {
CallToolParams {
name: "codex".into(),
arguments: Some(serde_json::to_value(CodexToolCallParam {
prompt,
cwd: root_dir,
config: Some(CodexConfig {
mcp_servers: Some(
mcp_config
.into_iter()
.map(|config| {
(mcp_server::SERVER_NAME.to_string(), config)
})
.collect(),
),
}),
})?),
meta: None,
}
};
client
.request_with::<CallTool>(params, Some(cancel_rx), None)
.await?;
Ok(AnyAgentResult::SendUserMessageResponse(
acp::SendUserMessageResponse,
))
}
AnyAgentRequest::CancelSendMessageParams(_) => {
if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
if let Some(cancel_tx) = borrow.take() {
cancel_tx.send(()).ok();
}
}
Ok(AnyAgentResult::CancelSendMessageResponse(
acp::CancelSendMessageResponse,
))
}
}
}
.boxed_local()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CodexConfig {
mcp_servers: Option<HashMap<String, McpServerConfig>>,
}
impl CodexAgentConnection {
async fn handle_elicitation(
delegate: &AcpClientDelegate,
elicitation: CodexElicitation,
) -> Result<(acp::ToolCallId, ReviewDecision)> {
let confirmation = match elicitation {
CodexElicitation::ExecApproval(exec) => {
let inner_command = strip_bash_lc_and_escape(&exec.codex_command);
acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: format!("`{inner_command}`"),
icon: acp::Icon::Terminal,
content: None,
locations: vec![],
},
confirmation: acp::ToolCallConfirmation::Execute {
root_command: inner_command
.split(" ")
.next()
.unwrap_or_default()
.to_string(),
command: inner_command,
description: Some(exec.message),
},
}
}
CodexElicitation::PatchApproval(patch) => {
acp::RequestToolCallConfirmationParams {
tool_call: acp::PushToolCallParams {
label: "Edit".to_string(),
icon: acp::Icon::Pencil,
content: None, // todo!()
locations: patch
.codex_changes
.keys()
.map(|path| acp::ToolCallLocation {
path: path.clone(),
line: None,
})
.collect(),
},
confirmation: acp::ToolCallConfirmation::Edit {
description: Some(patch.message),
},
}
}
};
let response = delegate
.request_tool_call_confirmation(confirmation)
.await?;
let decision = match response.outcome {
acp::ToolCallConfirmationOutcome::Allow => ReviewDecision::Approved,
acp::ToolCallConfirmationOutcome::AlwaysAllow
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
ReviewDecision::ApprovedForSession
}
acp::ToolCallConfirmationOutcome::Reject => ReviewDecision::Denied,
acp::ToolCallConfirmationOutcome::Cancel => ReviewDecision::Abort,
};
Ok((response.id, decision))
}
async fn handle_acp_notification(
delegate: &AcpClientDelegate,
event: AcpNotification,
session_id: &Rc<RefCell<Option<Uuid>>>,
tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
requested_call_id: &Rc<RefCell<Option<acp::ToolCallId>>>,
) -> Result<()> {
match event {
AcpNotification::SessionConfigured(sesh) => {
session_id.replace(Some(sesh.session_id));
}
AcpNotification::AgentMessage(message) => {
delegate
.stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
chunk: acp::AssistantMessageChunk::Text {
text: message.message,
},
})
.await?;
}
AcpNotification::AgentReasoning(message) => {
delegate
.stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
chunk: acp::AssistantMessageChunk::Thought {
thought: message.text,
},
})
.await?
}
AcpNotification::McpToolCallBegin(mut event) => {
if let Some(requested_tool_id) = requested_call_id.take() {
tool_id_map
.borrow_mut()
.insert(event.call_id, requested_tool_id);
} else {
let mut tool_call = acp::PushToolCallParams {
label: format!("`{}: {}`", event.server, event.tool),
icon: acp::Icon::Hammer,
content: event.arguments.as_ref().and_then(|args| {
Some(acp::ToolCallContent::Markdown {
markdown: md_codeblock(
"json",
&serde_json::to_string_pretty(args).ok()?,
),
})
}),
locations: vec![],
};
if event.server == mcp_server::SERVER_NAME
&& event.tool == mcp_server::EDIT_TOOL
&& let Some(params) = event.arguments.take().and_then(|args| {
serde_json::from_value::<EditToolParams>(args).log_err()
})
{
tool_call = acp::PushToolCallParams {
label: "Edit".into(),
icon: acp::Icon::Pencil,
content: Some(acp::ToolCallContent::Diff {
diff: acp::Diff {
path: params.abs_path.clone(),
old_text: Some(params.old_text),
new_text: params.new_text,
},
}),
locations: vec![acp::ToolCallLocation {
path: params.abs_path,
line: None,
}],
};
} else if event.server == mcp_server::SERVER_NAME
&& event.tool == mcp_server::READ_TOOL
&& let Some(params) = event.arguments.take().and_then(|args| {
serde_json::from_value::<ReadToolParams>(args).log_err()
})
{
tool_call = acp::PushToolCallParams {
label: "Read".into(),
icon: acp::Icon::FileSearch,
content: None,
locations: vec![acp::ToolCallLocation {
path: params.abs_path,
line: params.offset,
}],
}
}
let result = delegate.push_tool_call(tool_call).await?;
tool_id_map.borrow_mut().insert(event.call_id, result.id);
}
}
AcpNotification::McpToolCallEnd(event) => {
let acp_call_id = tool_id_map
.borrow_mut()
.remove(&event.call_id)
.context("Missing tool call")?;
let (status, content) = match event.result {
Ok(value) => {
if let Ok(response) =
serde_json::from_value::<context_server::types::CallToolResponse>(value)
{
(
acp::ToolCallStatus::Finished,
mcp_tool_content_to_acp(response.content),
)
} else {
(
acp::ToolCallStatus::Error,
Some(acp::ToolCallContent::Markdown {
markdown: "Failed to parse tool response".to_string(),
}),
)
}
}
Err(error) => (
acp::ToolCallStatus::Error,
Some(acp::ToolCallContent::Markdown { markdown: error }),
),
};
delegate
.update_tool_call(acp::UpdateToolCallParams {
tool_call_id: acp_call_id,
status,
content,
})
.await?;
}
AcpNotification::ExecCommandBegin(event) => {
if let Some(requested_tool_id) = requested_call_id.take() {
tool_id_map
.borrow_mut()
.insert(event.call_id, requested_tool_id);
} else {
let inner_command = strip_bash_lc_and_escape(&event.command);
let result = delegate
.push_tool_call(acp::PushToolCallParams {
label: format!("`{}`", inner_command),
icon: acp::Icon::Terminal,
content: None,
locations: vec![],
})
.await?;
tool_id_map.borrow_mut().insert(event.call_id, result.id);
}
}
AcpNotification::ExecCommandEnd(event) => {
let acp_call_id = tool_id_map
.borrow_mut()
.remove(&event.call_id)
.context("Missing tool call")?;
let mut content = String::new();
if !event.stdout.is_empty() {
use std::fmt::Write;
writeln!(
&mut content,
"### Output\n\n{}",
md_codeblock("", &event.stdout)
)
.unwrap();
}
if !event.stdout.is_empty() && !event.stderr.is_empty() {
use std::fmt::Write;
writeln!(&mut content).unwrap();
}
if !event.stderr.is_empty() {
use std::fmt::Write;
writeln!(
&mut content,
"### Error\n\n{}",
md_codeblock("", &event.stderr)
)
.unwrap();
}
let success = event.exit_code == 0;
if !success {
use std::fmt::Write;
writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
}
delegate
.update_tool_call(acp::UpdateToolCallParams {
tool_call_id: acp_call_id,
status: if success {
acp::ToolCallStatus::Finished
} else {
acp::ToolCallStatus::Error
},
content: Some(acp::ToolCallContent::Markdown { markdown: content }),
})
.await?;
}
AcpNotification::Other => {}
}
Ok(())
}
}
/// todo! use types from h2a crate when we have one
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) struct CodexToolCallParam {
pub prompt: String,
pub cwd: PathBuf,
pub config: Option<CodexConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct CodexToolCallReplyParam {
pub session_id: Uuid,
pub prompt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CodexEvent {
pub msg: AcpNotification,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AcpNotification {
SessionConfigured(SessionConfiguredEvent),
AgentMessage(AgentMessageEvent),
AgentReasoning(AgentReasoningEvent),
McpToolCallBegin(McpToolCallBeginEvent),
McpToolCallEnd(McpToolCallEndEvent),
ExecCommandBegin(ExecCommandBeginEvent),
ExecCommandEnd(ExecCommandEndEvent),
#[serde(other)]
Other,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMessageEvent {
pub message: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentReasoningEvent {
pub text: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallBeginEvent {
pub call_id: String,
pub server: String,
pub tool: String,
pub arguments: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallEndEvent {
pub call_id: String,
pub result: Result<serde_json::Value, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandBeginEvent {
pub call_id: String,
pub command: Vec<String>,
pub cwd: PathBuf,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandEndEvent {
pub call_id: String,
pub stdout: String,
pub stderr: String,
pub exit_code: i32,
}
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct SessionConfiguredEvent {
pub session_id: Uuid,
}
// Helper functions
fn md_codeblock(lang: &str, content: &str) -> String {
if content.ends_with('\n') {
format!("```{}\n{}```", lang, content)
} else {
format!("```{}\n{}\n```", lang, content)
}
}
fn strip_bash_lc_and_escape(command: &[String]) -> String {
match command {
// exactly three items
[first, second, third]
// first two must be "bash", "-lc"
if first == "bash" && second == "-lc" =>
{
third.clone()
}
_ => escape_command(command),
}
}
fn escape_command(command: &[String]) -> String {
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
}
fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
let mut content = String::new();
for chunk in chunks {
match chunk {
ToolResponseContent::Text { text } => content.push_str(&text),
ToolResponseContent::Image { .. } => {
// todo!
}
ToolResponseContent::Audio { .. } => {
// todo!
}
ToolResponseContent::Resource { .. } => {
// todo!
}
}
}
if !content.is_empty() {
Some(acp::ToolCallContent::Markdown { markdown: content })
} else {
None
}
}
pub struct CodexApproval;
impl context_server::types::Request for CodexApproval {
type Params = CodexElicitation;
type Response = CodexApprovalResponse;
const METHOD: &'static str = "elicitation/create";
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecApprovalRequest {
// These fields are required so that `params`
// conforms to ElicitRequestParams.
pub message: String,
// #[serde(rename = "requestedSchema")]
// pub requested_schema: ElicitRequestParamsRequestedSchema,
// // These are additional fields the client can use to
// // correlate the request with the codex tool call.
pub codex_mcp_tool_call_id: String,
// pub codex_event_id: String,
pub codex_command: Vec<String>,
pub codex_cwd: PathBuf,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PatchApprovalRequest {
pub message: String,
// #[serde(rename = "requestedSchema")]
// pub requested_schema: ElicitRequestParamsRequestedSchema,
pub codex_mcp_tool_call_id: String,
pub codex_event_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub codex_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub codex_grant_root: Option<PathBuf>,
pub codex_changes: HashMap<PathBuf, FileChange>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "codex_elicitation", rename_all = "kebab-case")]
pub enum CodexElicitation {
ExecApproval(ExecApprovalRequest),
PatchApproval(PatchApprovalRequest),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FileChange {
Add {
content: String,
},
Delete,
Update {
unified_diff: String,
move_path: Option<PathBuf>,
},
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CodexApprovalResponse {
pub decision: ReviewDecision,
}
/// User's decision in response to an ExecApprovalRequest.
#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReviewDecision {
/// User has approved this command and the agent should execute it.
Approved,
/// User has approved this command and wants to automatically approve any
/// future identical instances (`command` and `cwd` match exactly) for the
/// remainder of the session.
ApprovedForSession,
/// User has denied this command and the agent should not execute it, but
/// it should continue the session and try something else.
#[default]
Denied,
/// User has denied this command and the agent should not do anything until
/// the user's next command.
Abort,
}
#[cfg(test)]
pub mod tests {
use super::*;
crate::common_e2e_tests!(Codex);
pub fn local_command() -> AgentServerCommand {
let cli_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../codex/code-rs/target/debug/codex");
AgentServerCommand {
path: cli_path,
args: vec!["mcp".into()],
env: None,
}
}
}

View file

@ -350,6 +350,9 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
claude: Some(AgentServerSettings {
command: crate::claude::tests::local_command(),
}),
codex: Some(AgentServerSettings {
command: crate::codex::tests::local_command(),
}),
gemini: Some(AgentServerSettings {
command: crate::gemini::tests::local_command(),
}),

View file

@ -1,29 +1,23 @@
use std::{cell::RefCell, rc::Rc};
use std::{cell::RefCell, path::PathBuf, rc::Rc};
use acp_thread::AcpClientDelegate;
use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams};
use anyhow::{Context, Result};
use collections::HashMap;
use context_server::{
listener::McpServer,
types::{
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
ToolResponseContent, ToolsCapabilities, requests,
},
use context_server::types::{
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
ToolResponseContent, ToolsCapabilities, requests,
};
use gpui::{App, AsyncApp, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::debug_panic;
use crate::claude::{
McpServerConfig,
tools::{ClaudeTool, EditToolParams, ReadToolParams},
};
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
pub struct ClaudeMcpServer {
server: McpServer,
pub struct ZedMcpServer {
server: context_server::listener::McpServer,
}
pub const SERVER_NAME: &str = "zed";
@ -52,15 +46,36 @@ enum PermissionToolBehavior {
Deny,
}
impl ClaudeMcpServer {
#[derive(Clone)]
pub struct EnabledTools {
pub read: bool,
pub edit: bool,
pub permission: bool,
}
impl Default for EnabledTools {
fn default() -> Self {
EnabledTools {
read: true,
edit: true,
permission: true,
}
}
}
impl ZedMcpServer {
pub async fn new(
delegate: watch::Receiver<Option<AcpClientDelegate>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
enabled_tools: EnabledTools,
cx: &AsyncApp,
) -> Result<Self> {
let mut mcp_server = McpServer::new(cx).await?;
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
mcp_server.handle_request::<requests::ListTools>({
let enabled_tools = enabled_tools.clone();
move |_, cx| Self::handle_list_tools(enabled_tools.clone(), cx)
});
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx)
});
@ -70,9 +85,7 @@ impl ClaudeMcpServer {
pub fn server_config(&self) -> Result<McpServerConfig> {
let zed_path = std::env::current_exe()
.context("finding current executable path for use in mcp_server")?
.to_string_lossy()
.to_string();
.context("finding current executable path for use in mcp_server")?;
Ok(McpServerConfig {
command: zed_path,
@ -107,17 +120,17 @@ impl ClaudeMcpServer {
})
}
fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
fn handle_list_tools(enabled: EnabledTools, cx: &App) -> Task<Result<ListToolsResponse>> {
cx.foreground_executor().spawn(async move {
Ok(ListToolsResponse {
tools: vec![
Tool {
tools: [
enabled.permission.then(|| Tool {
name: PERMISSION_TOOL.into(),
input_schema: schemars::schema_for!(PermissionToolParams).into(),
description: None,
annotations: None,
},
Tool {
}),
enabled.read.then(|| Tool {
name: READ_TOOL.into(),
input_schema: schemars::schema_for!(ReadToolParams).into(),
description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
@ -130,8 +143,8 @@ impl ClaudeMcpServer {
// true or false seem too strong, let's try a none.
idempotent_hint: None,
}),
},
Tool {
}),
enabled.edit.then(|| Tool {
name: EDIT_TOOL.into(),
input_schema: schemars::schema_for!(EditToolParams).into(),
description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
@ -142,8 +155,8 @@ impl ClaudeMcpServer {
open_world_hint: Some(false),
idempotent_hint: Some(false),
}),
},
],
}),
].into_iter().flatten().collect(),
next_cursor: None,
meta: None,
})
@ -294,3 +307,18 @@ impl ClaudeMcpServer {
})
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpConfig {
pub mcp_servers: HashMap<String, McpServerConfig>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct McpServerConfig {
pub command: PathBuf,
pub args: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<HashMap<String, String>>,
}

View file

@ -13,6 +13,7 @@ pub fn init(cx: &mut App) {
pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>,
pub claude: Option<AgentServerSettings>,
pub codex: Option<AgentServerSettings>,
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@ -29,13 +30,21 @@ impl settings::Settings for AllAgentServersSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
for AllAgentServersSettings {
gemini,
claude,
codex,
} in sources.defaults_and_customizations()
{
if gemini.is_some() {
settings.gemini = gemini.clone();
}
if claude.is_some() {
settings.claude = claude.clone();
}
if codex.is_some() {
settings.codex = codex.clone();
}
}
Ok(settings)

View file

@ -1126,21 +1126,27 @@ impl AcpThreadView {
))
.into_any(),
ToolCallConfirmation::Execute {
command,
command: _,
root_command,
description,
} => confirmation_container
.child(v_flex().px_2().pb_1p5().child(command.clone()).children(
description.clone().map(|description| {
self.render_markdown(description, default_markdown_style(false, window, cx))
.child(
v_flex()
.px_2()
.pb_1p5()
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
Self::open_link(text, &workspace, window, cx);
}
})
}),
))
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {

View file

@ -1980,6 +1980,13 @@ impl AgentPanel {
);
}),
)
.action(
"New Codex Thread",
NewExternalAgentThread {
agent: Some(crate::ExternalAgent::Codex),
}
.boxed_clone(),
)
});
menu
}))

View file

@ -150,6 +150,7 @@ enum ExternalAgent {
#[default]
Gemini,
ClaudeCode,
Codex,
}
impl ExternalAgent {
@ -157,6 +158,7 @@ impl ExternalAgent {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
}
}
}

View file

@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
use futures::{FutureExt, StreamExt, channel::oneshot, select};
use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
@ -10,15 +10,19 @@ use smol::channel;
use std::{
fmt,
path::PathBuf,
pin::pin,
sync::{
Arc,
atomic::{AtomicI32, Ordering::SeqCst},
},
time::{Duration, Instant},
};
use util::TryFutureExt;
use util::{ResultExt, TryFutureExt};
use crate::transport::{StdioTransport, Transport};
use crate::{
transport::{StdioTransport, Transport},
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
};
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
@ -46,6 +51,7 @@ pub(crate) struct Client {
outbound_tx: channel::Sender<String>,
name: Arc<str>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
#[allow(clippy::type_complexity)]
#[allow(dead_code)]
@ -78,6 +84,15 @@ pub struct Request<'a, T> {
pub params: T,
}
#[derive(Serialize, Deserialize)]
pub struct AnyRequest<'a> {
pub jsonrpc: &'a str,
pub id: RequestId,
pub method: &'a str,
#[serde(skip_serializing_if = "is_null_value")]
pub params: Option<&'a RawValue>,
}
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
jsonrpc: &'a str,
@ -176,15 +191,23 @@ impl Client {
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let request_handlers = request_handlers.clone();
let transport = transport.clone();
async move |cx| {
Self::handle_input(transport, notification_handlers, response_handlers, cx)
.log_err()
.await
Self::handle_input(
transport,
notification_handlers,
request_handlers,
response_handlers,
cx,
)
.log_err()
.await
}
});
let receive_err_task = cx.spawn({
@ -211,6 +234,7 @@ impl Client {
server_id,
notification_handlers,
response_handlers,
request_handlers,
name: server_name,
next_id: Default::default(),
outbound_tx,
@ -230,13 +254,24 @@ impl Client {
async fn handle_input(
transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: &mut AsyncApp,
) -> anyhow::Result<()> {
let mut receiver = transport.receive();
while let Some(message) = receiver.next().await {
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
log::trace!("recv: {}", &message);
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
let mut request_handlers = request_handlers.lock();
if let Some(handler) = request_handlers.get_mut(request.method) {
handler(
request.id,
request.params.unwrap_or(RawValue::NULL),
cx.clone(),
);
}
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string()));
@ -247,6 +282,8 @@ impl Client {
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
} else {
log::error!("Unhandled JSON from context_server: {}", message);
}
}
@ -294,6 +331,17 @@ impl Client {
&self,
method: &str,
params: impl Serialize,
) -> Result<T> {
self.request_with(method, params, None, Some(REQUEST_TIMEOUT))
.await
}
pub async fn request_with<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request {
@ -329,7 +377,25 @@ impl Client {
handle_response?;
send?;
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
let mut timeout_fut = pin!(
if let Some(timeout) = timeout {
future::Either::Left(executor.timer(timeout))
} else {
future::Either::Right(future::pending())
}
.fuse()
);
let mut cancel_fut = pin!(
match cancel_rx {
Some(rx) => future::Either::Left(async {
rx.await.log_err();
}),
None => future::Either::Right(future::pending()),
}
.fuse()
);
select! {
response = rx.fuse() => {
let elapsed = started.elapsed();
@ -348,8 +414,18 @@ impl Client {
Err(_) => anyhow::bail!("cancelled")
}
}
_ = timeout => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
_ = cancel_fut => {
self.notify(
Cancelled::METHOD,
ClientNotification::Cancelled(CancelledParams {
request_id: RequestId::Int(id),
reason: None
})
).log_err();
anyhow::bail!("Request cancelled")
}
_ = timeout_fut => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout);
anyhow::bail!("Context server request timeout");
}
}
@ -377,6 +453,79 @@ impl Client {
.lock()
.insert(method, Box::new(f));
}
pub fn on_request<R: crate::types::Request, F>(&self, mut f: F)
where
F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
{
let outbound_tx = self.outbound_tx.clone();
self.request_handlers.lock().insert(
R::METHOD,
Box::new(move |id, json, cx| {
let outbound_tx = outbound_tx.clone();
match serde_json::from_str(json.get()) {
Ok(req) => {
let task = f(req, cx.clone());
cx.foreground_executor()
.spawn(async move {
match task.await {
Ok(res) => {
outbound_tx
.send(
serde_json::to_string(&Response {
jsonrpc: JSON_RPC_VERSION,
id,
value: CspResult::Ok(Some(res)),
})
.unwrap(),
)
.await
.ok();
}
Err(e) => {
outbound_tx
.send(
serde_json::to_string(&Response {
jsonrpc: JSON_RPC_VERSION,
id,
value: CspResult::<()>::Error(Some(Error {
code: -1, // todo!()
message: format!("{e}"),
})),
})
.unwrap(),
)
.await
.ok();
}
}
})
.detach();
}
Err(e) => {
cx.foreground_executor()
.spawn(async move {
outbound_tx
.send(
serde_json::to_string(&Response {
jsonrpc: JSON_RPC_VERSION,
id,
value: CspResult::<()>::Error(Some(Error {
code: -1, // todo!()
message: format!("{e}"),
})),
})
.unwrap(),
)
.await
.ok();
})
.detach();
}
}
}),
);
}
}
impl fmt::Display for ContextServerId {

View file

@ -5,7 +5,12 @@
//! read/write messages and the types from types.rs for serialization/deserialization
//! of messages.
use std::time::Duration;
use anyhow::Result;
use futures::channel::oneshot;
use gpui::{AsyncApp, Task};
use serde_json::Value;
use crate::client::Client;
use crate::types::{self, Notification, Request};
@ -95,7 +100,32 @@ impl InitializedContextServerProtocol {
self.inner.request(T::METHOD, params).await
}
pub async fn request_with<T: Request>(
&self,
params: T::Params,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T::Response> {
self.inner
.request_with(T::METHOD, params, cancel_rx, timeout)
.await
}
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params)
}
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.inner.on_notification(method, f);
}
pub fn on_request<R: crate::types::Request, F>(&self, f: F)
where
F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
{
self.inner.on_request::<R, F>(f);
}
}

View file

@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::client::RequestId;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05";
@ -100,6 +102,7 @@ pub mod notifications {
notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams);
notification!("notifications/cancelled", Cancelled, CancelledParams);
notification!(
"notifications/resources/updated",
ResourcesUpdated,
@ -617,11 +620,14 @@ pub enum ClientNotification {
Initialized,
Progress(ProgressParams),
RootsListChanged,
Cancelled {
request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>,
},
Cancelled(CancelledParams),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CancelledParams {
pub request_id: RequestId,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]

View file

@ -160,6 +160,7 @@ zed_actions.workspace = true
zeta.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
erased-serde = "0.4.6"
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true