Compare commits
20 commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
2a00a53fcf | ||
![]() |
1e5625c4b4 | ||
![]() |
20c0a06485 | ||
![]() |
598a8180b5 | ||
![]() |
07a08f0972 | ||
![]() |
cedd6aa704 | ||
![]() |
966d29dcd9 | ||
![]() |
cede9d757a | ||
![]() |
c0c698b883 | ||
![]() |
0fa7d58a3e | ||
![]() |
9b91445967 | ||
![]() |
480adade63 | ||
![]() |
47dec0df99 | ||
![]() |
f20edf1b50 | ||
![]() |
03b94f5831 | ||
![]() |
e7298c0736 | ||
![]() |
a822711e99 | ||
![]() |
4b1ace9a54 | ||
![]() |
769d6dc632 | ||
![]() |
f56910556f |
15 changed files with 1114 additions and 74 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -159,6 +159,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"shlex",
|
||||
"smol",
|
||||
"strum 0.27.1",
|
||||
"tempfile",
|
||||
|
@ -20204,6 +20205,7 @@ dependencies = [
|
|||
"diagnostics",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"erased-serde",
|
||||
"extension",
|
||||
"extension_host",
|
||||
"extensions_ui",
|
||||
|
|
|
@ -40,6 +40,7 @@ util.workspace = true
|
|||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
which.workspace = true
|
||||
shlex.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
mod claude;
|
||||
mod codex;
|
||||
mod gemini;
|
||||
mod mcp_server;
|
||||
mod settings;
|
||||
mod stdio_agent_server;
|
||||
|
||||
|
@ -7,6 +9,7 @@ mod stdio_agent_server;
|
|||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use codex::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
pub use stdio_agent_server::*;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
mod mcp_server;
|
||||
mod tools;
|
||||
pub mod tools;
|
||||
|
||||
use collections::HashMap;
|
||||
use project::Project;
|
||||
|
@ -30,8 +29,8 @@ use gpui::{App, AppContext, Entity, Task};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::claude::mcp_server::ClaudeMcpServer;
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
use crate::mcp_server::{self, McpConfig, ZedMcpServer};
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
|
||||
|
||||
|
@ -72,12 +71,13 @@ impl AgentServer for ClaudeCode {
|
|||
let (mut delegate_tx, delegate_rx) = watch::channel(None);
|
||||
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
|
||||
let permission_mcp_server =
|
||||
ZedMcpServer::new(delegate_rx, tool_id_map.clone(), Default::default(), cx).await?;
|
||||
|
||||
let mut mcp_servers = HashMap::default();
|
||||
mcp_servers.insert(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
mcp_server.server_config()?,
|
||||
crate::mcp_server::SERVER_NAME.to_string(),
|
||||
permission_mcp_server.server_config()?,
|
||||
);
|
||||
let mcp_config = McpConfig { mcp_servers };
|
||||
|
||||
|
@ -187,7 +187,7 @@ impl AgentServer for ClaudeCode {
|
|||
_mcp_server: None,
|
||||
};
|
||||
|
||||
connection._mcp_server = Some(mcp_server);
|
||||
connection._mcp_server = Some(permission_mcp_server);
|
||||
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
|
||||
})
|
||||
})
|
||||
|
@ -333,7 +333,7 @@ struct ClaudeAgentConnection {
|
|||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
|
||||
_mcp_server: Option<ClaudeMcpServer>,
|
||||
_mcp_server: Option<ZedMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
|
@ -661,21 +661,6 @@ enum PermissionMode {
|
|||
Plan,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpConfig {
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpServerConfig {
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
|
|
808
crates/agent_servers/src/codex.rs
Normal file
808
crates/agent_servers/src/codex.rs
Normal file
|
@ -0,0 +1,808 @@
|
|||
use collections::HashMap;
|
||||
use context_server::types::requests::CallTool;
|
||||
use context_server::types::{CallToolParams, ToolResponseContent};
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt;
|
||||
use std::cell::RefCell;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use agentic_coding_protocol::{
|
||||
self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
|
||||
};
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::{FutureExt, SinkExt as _};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::mcp_server::{self, McpServerConfig, ZedMcpServer};
|
||||
use crate::tools::{EditToolParams, ReadToolParams};
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Codex;
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
self.name()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let project = project.clone();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let title = self.name().into();
|
||||
cx.spawn(async move |cx| {
|
||||
let (mut delegate_tx, delegate_rx) = watch::channel(None);
|
||||
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let zed_mcp_server = ZedMcpServer::new(
|
||||
delegate_rx,
|
||||
tool_id_map.clone(),
|
||||
mcp_server::EnabledTools {
|
||||
permission: false,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find codex binary");
|
||||
};
|
||||
|
||||
let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
|
||||
ContextServerId("codex-mcp-server".into()),
|
||||
ContextServerCommand {
|
||||
path: command.path,
|
||||
args: command.args,
|
||||
env: command.env,
|
||||
},
|
||||
)
|
||||
.into();
|
||||
|
||||
ContextServer::start(codex_mcp_client.clone(), cx).await?;
|
||||
// todo! stop
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
let (request_tx, mut request_rx) = mpsc::unbounded();
|
||||
|
||||
let client = codex_mcp_client
|
||||
.client()
|
||||
.context("Failed to subscribe to server")?;
|
||||
|
||||
client.on_notification("codex/event", {
|
||||
move |event, cx| {
|
||||
let mut notification_tx = notification_tx.clone();
|
||||
cx.background_spawn(async move {
|
||||
log::trace!("Notification: {:?}", serde_json::to_string_pretty(&event));
|
||||
if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
|
||||
notification_tx.send(event.msg).await.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
});
|
||||
|
||||
client.on_request::<CodexApproval, _>({
|
||||
move |elicitation, cx| {
|
||||
let (tx, rx) = oneshot::channel::<Result<CodexApprovalResponse>>();
|
||||
let mut request_tx = request_tx.clone();
|
||||
cx.background_spawn(async move {
|
||||
log::trace!("Elicitation: {:?}", elicitation);
|
||||
request_tx.send((elicitation, tx)).await?;
|
||||
rx.await?
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let requested_call_id = Rc::new(RefCell::new(None));
|
||||
let session_id = Rc::new(RefCell::new(None));
|
||||
|
||||
cx.new(|cx| {
|
||||
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
|
||||
delegate_tx.send(Some(delegate.clone())).log_err();
|
||||
|
||||
let handler_task = cx.spawn({
|
||||
let delegate = delegate.clone();
|
||||
let tool_id_map = tool_id_map.clone();
|
||||
let requested_call_id = requested_call_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
async move |_, _cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
CodexAgentConnection::handle_acp_notification(
|
||||
&delegate,
|
||||
notification,
|
||||
&session_id,
|
||||
&tool_id_map,
|
||||
&requested_call_id,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let request_task = cx.spawn({
|
||||
let delegate = delegate.clone();
|
||||
async move |_, _cx| {
|
||||
while let Some((elicitation, respond)) = request_rx.next().await {
|
||||
if let Some((id, decision)) =
|
||||
CodexAgentConnection::handle_elicitation(&delegate, elicitation)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
requested_call_id.replace(Some(id));
|
||||
|
||||
respond
|
||||
.send(Ok(CodexApprovalResponse { decision }))
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection = CodexAgentConnection {
|
||||
root_dir,
|
||||
codex_mcp: codex_mcp_client,
|
||||
cancel_request_tx: Default::default(),
|
||||
session_id,
|
||||
zed_mcp_server,
|
||||
_handler_task: handler_task,
|
||||
_request_task: request_task,
|
||||
};
|
||||
|
||||
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct CodexAgentConnection {
|
||||
codex_mcp: Arc<context_server::ContextServer>,
|
||||
root_dir: PathBuf,
|
||||
cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
|
||||
session_id: Rc<RefCell<Option<Uuid>>>,
|
||||
zed_mcp_server: ZedMcpServer,
|
||||
_handler_task: Task<()>,
|
||||
_request_task: Task<()>,
|
||||
}
|
||||
|
||||
impl AgentConnection for CodexAgentConnection {
|
||||
/// Send a request to the agent and wait for a response.
|
||||
fn request_any(
|
||||
&self,
|
||||
params: AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
|
||||
let client = self.codex_mcp.client();
|
||||
let root_dir = self.root_dir.clone();
|
||||
let cancel_request_tx = self.cancel_request_tx.clone();
|
||||
let mcp_config = self.zed_mcp_server.server_config();
|
||||
let session_id = self.session_id.clone();
|
||||
async move {
|
||||
let client = client.context("Codex MCP server is not initialized")?;
|
||||
|
||||
match params {
|
||||
// todo: consider sending an empty request so we get the init response?
|
||||
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
|
||||
acp::InitializeResponse {
|
||||
is_authenticated: true,
|
||||
protocol_version: ProtocolVersion::latest(),
|
||||
},
|
||||
)),
|
||||
AnyAgentRequest::AuthenticateParams(_) => {
|
||||
Err(anyhow!("Authentication not supported"))
|
||||
}
|
||||
AnyAgentRequest::SendUserMessageParams(message) => {
|
||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||
cancel_request_tx.borrow_mut().replace(new_cancel_tx);
|
||||
|
||||
let prompt = message
|
||||
.chunks
|
||||
.into_iter()
|
||||
.filter_map(|chunk| match chunk {
|
||||
acp::UserMessageChunk::Text { text } => Some(text),
|
||||
acp::UserMessageChunk::Path { .. } => {
|
||||
// todo!
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let params = if let Some(session_id) = *session_id.borrow() {
|
||||
CallToolParams {
|
||||
name: "codex-reply".into(),
|
||||
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
|
||||
prompt,
|
||||
session_id,
|
||||
})?),
|
||||
meta: None,
|
||||
}
|
||||
} else {
|
||||
CallToolParams {
|
||||
name: "codex".into(),
|
||||
arguments: Some(serde_json::to_value(CodexToolCallParam {
|
||||
prompt,
|
||||
cwd: root_dir,
|
||||
config: Some(CodexConfig {
|
||||
mcp_servers: Some(
|
||||
mcp_config
|
||||
.into_iter()
|
||||
.map(|config| {
|
||||
(mcp_server::SERVER_NAME.to_string(), config)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
}),
|
||||
})?),
|
||||
meta: None,
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.request_with::<CallTool>(params, Some(cancel_rx), None)
|
||||
.await?;
|
||||
|
||||
Ok(AnyAgentResult::SendUserMessageResponse(
|
||||
acp::SendUserMessageResponse,
|
||||
))
|
||||
}
|
||||
AnyAgentRequest::CancelSendMessageParams(_) => {
|
||||
if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
|
||||
if let Some(cancel_tx) = borrow.take() {
|
||||
cancel_tx.send(()).ok();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AnyAgentResult::CancelSendMessageResponse(
|
||||
acp::CancelSendMessageResponse,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CodexConfig {
|
||||
mcp_servers: Option<HashMap<String, McpServerConfig>>,
|
||||
}
|
||||
|
||||
impl CodexAgentConnection {
|
||||
async fn handle_elicitation(
|
||||
delegate: &AcpClientDelegate,
|
||||
elicitation: CodexElicitation,
|
||||
) -> Result<(acp::ToolCallId, ReviewDecision)> {
|
||||
let confirmation = match elicitation {
|
||||
CodexElicitation::ExecApproval(exec) => {
|
||||
let inner_command = strip_bash_lc_and_escape(&exec.codex_command);
|
||||
|
||||
acp::RequestToolCallConfirmationParams {
|
||||
tool_call: acp::PushToolCallParams {
|
||||
label: format!("`{inner_command}`"),
|
||||
icon: acp::Icon::Terminal,
|
||||
content: None,
|
||||
locations: vec![],
|
||||
},
|
||||
confirmation: acp::ToolCallConfirmation::Execute {
|
||||
root_command: inner_command
|
||||
.split(" ")
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
command: inner_command,
|
||||
description: Some(exec.message),
|
||||
},
|
||||
}
|
||||
}
|
||||
CodexElicitation::PatchApproval(patch) => {
|
||||
acp::RequestToolCallConfirmationParams {
|
||||
tool_call: acp::PushToolCallParams {
|
||||
label: "Edit".to_string(),
|
||||
icon: acp::Icon::Pencil,
|
||||
content: None, // todo!()
|
||||
locations: patch
|
||||
.codex_changes
|
||||
.keys()
|
||||
.map(|path| acp::ToolCallLocation {
|
||||
path: path.clone(),
|
||||
line: None,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
confirmation: acp::ToolCallConfirmation::Edit {
|
||||
description: Some(patch.message),
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let response = delegate
|
||||
.request_tool_call_confirmation(confirmation)
|
||||
.await?;
|
||||
|
||||
let decision = match response.outcome {
|
||||
acp::ToolCallConfirmationOutcome::Allow => ReviewDecision::Approved,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
|
||||
ReviewDecision::ApprovedForSession
|
||||
}
|
||||
acp::ToolCallConfirmationOutcome::Reject => ReviewDecision::Denied,
|
||||
acp::ToolCallConfirmationOutcome::Cancel => ReviewDecision::Abort,
|
||||
};
|
||||
|
||||
Ok((response.id, decision))
|
||||
}
|
||||
|
||||
async fn handle_acp_notification(
|
||||
delegate: &AcpClientDelegate,
|
||||
event: AcpNotification,
|
||||
session_id: &Rc<RefCell<Option<Uuid>>>,
|
||||
tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
requested_call_id: &Rc<RefCell<Option<acp::ToolCallId>>>,
|
||||
) -> Result<()> {
|
||||
match event {
|
||||
AcpNotification::SessionConfigured(sesh) => {
|
||||
session_id.replace(Some(sesh.session_id));
|
||||
}
|
||||
AcpNotification::AgentMessage(message) => {
|
||||
delegate
|
||||
.stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Text {
|
||||
text: message.message,
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
AcpNotification::AgentReasoning(message) => {
|
||||
delegate
|
||||
.stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Thought {
|
||||
thought: message.text,
|
||||
},
|
||||
})
|
||||
.await?
|
||||
}
|
||||
AcpNotification::McpToolCallBegin(mut event) => {
|
||||
if let Some(requested_tool_id) = requested_call_id.take() {
|
||||
tool_id_map
|
||||
.borrow_mut()
|
||||
.insert(event.call_id, requested_tool_id);
|
||||
} else {
|
||||
let mut tool_call = acp::PushToolCallParams {
|
||||
label: format!("`{}: {}`", event.server, event.tool),
|
||||
icon: acp::Icon::Hammer,
|
||||
content: event.arguments.as_ref().and_then(|args| {
|
||||
Some(acp::ToolCallContent::Markdown {
|
||||
markdown: md_codeblock(
|
||||
"json",
|
||||
&serde_json::to_string_pretty(args).ok()?,
|
||||
),
|
||||
})
|
||||
}),
|
||||
locations: vec![],
|
||||
};
|
||||
|
||||
if event.server == mcp_server::SERVER_NAME
|
||||
&& event.tool == mcp_server::EDIT_TOOL
|
||||
&& let Some(params) = event.arguments.take().and_then(|args| {
|
||||
serde_json::from_value::<EditToolParams>(args).log_err()
|
||||
})
|
||||
{
|
||||
tool_call = acp::PushToolCallParams {
|
||||
label: "Edit".into(),
|
||||
icon: acp::Icon::Pencil,
|
||||
content: Some(acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: params.abs_path.clone(),
|
||||
old_text: Some(params.old_text),
|
||||
new_text: params.new_text,
|
||||
},
|
||||
}),
|
||||
locations: vec![acp::ToolCallLocation {
|
||||
path: params.abs_path,
|
||||
line: None,
|
||||
}],
|
||||
};
|
||||
} else if event.server == mcp_server::SERVER_NAME
|
||||
&& event.tool == mcp_server::READ_TOOL
|
||||
&& let Some(params) = event.arguments.take().and_then(|args| {
|
||||
serde_json::from_value::<ReadToolParams>(args).log_err()
|
||||
})
|
||||
{
|
||||
tool_call = acp::PushToolCallParams {
|
||||
label: "Read".into(),
|
||||
icon: acp::Icon::FileSearch,
|
||||
content: None,
|
||||
locations: vec![acp::ToolCallLocation {
|
||||
path: params.abs_path,
|
||||
line: params.offset,
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
let result = delegate.push_tool_call(tool_call).await?;
|
||||
|
||||
tool_id_map.borrow_mut().insert(event.call_id, result.id);
|
||||
}
|
||||
}
|
||||
AcpNotification::McpToolCallEnd(event) => {
|
||||
let acp_call_id = tool_id_map
|
||||
.borrow_mut()
|
||||
.remove(&event.call_id)
|
||||
.context("Missing tool call")?;
|
||||
|
||||
let (status, content) = match event.result {
|
||||
Ok(value) => {
|
||||
if let Ok(response) =
|
||||
serde_json::from_value::<context_server::types::CallToolResponse>(value)
|
||||
{
|
||||
(
|
||||
acp::ToolCallStatus::Finished,
|
||||
mcp_tool_content_to_acp(response.content),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
acp::ToolCallStatus::Error,
|
||||
Some(acp::ToolCallContent::Markdown {
|
||||
markdown: "Failed to parse tool response".to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(error) => (
|
||||
acp::ToolCallStatus::Error,
|
||||
Some(acp::ToolCallContent::Markdown { markdown: error }),
|
||||
),
|
||||
};
|
||||
|
||||
delegate
|
||||
.update_tool_call(acp::UpdateToolCallParams {
|
||||
tool_call_id: acp_call_id,
|
||||
status,
|
||||
content,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
AcpNotification::ExecCommandBegin(event) => {
|
||||
if let Some(requested_tool_id) = requested_call_id.take() {
|
||||
tool_id_map
|
||||
.borrow_mut()
|
||||
.insert(event.call_id, requested_tool_id);
|
||||
} else {
|
||||
let inner_command = strip_bash_lc_and_escape(&event.command);
|
||||
|
||||
let result = delegate
|
||||
.push_tool_call(acp::PushToolCallParams {
|
||||
label: format!("`{}`", inner_command),
|
||||
icon: acp::Icon::Terminal,
|
||||
content: None,
|
||||
locations: vec![],
|
||||
})
|
||||
.await?;
|
||||
|
||||
tool_id_map.borrow_mut().insert(event.call_id, result.id);
|
||||
}
|
||||
}
|
||||
AcpNotification::ExecCommandEnd(event) => {
|
||||
let acp_call_id = tool_id_map
|
||||
.borrow_mut()
|
||||
.remove(&event.call_id)
|
||||
.context("Missing tool call")?;
|
||||
|
||||
let mut content = String::new();
|
||||
if !event.stdout.is_empty() {
|
||||
use std::fmt::Write;
|
||||
writeln!(
|
||||
&mut content,
|
||||
"### Output\n\n{}",
|
||||
md_codeblock("", &event.stdout)
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
if !event.stdout.is_empty() && !event.stderr.is_empty() {
|
||||
use std::fmt::Write;
|
||||
writeln!(&mut content).unwrap();
|
||||
}
|
||||
if !event.stderr.is_empty() {
|
||||
use std::fmt::Write;
|
||||
writeln!(
|
||||
&mut content,
|
||||
"### Error\n\n{}",
|
||||
md_codeblock("", &event.stderr)
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
let success = event.exit_code == 0;
|
||||
if !success {
|
||||
use std::fmt::Write;
|
||||
writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
|
||||
}
|
||||
|
||||
delegate
|
||||
.update_tool_call(acp::UpdateToolCallParams {
|
||||
tool_call_id: acp_call_id,
|
||||
status: if success {
|
||||
acp::ToolCallStatus::Finished
|
||||
} else {
|
||||
acp::ToolCallStatus::Error
|
||||
},
|
||||
content: Some(acp::ToolCallContent::Markdown { markdown: content }),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
AcpNotification::Other => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// todo! use types from h2a crate when we have one
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub(crate) struct CodexToolCallParam {
|
||||
pub prompt: String,
|
||||
pub cwd: PathBuf,
|
||||
pub config: Option<CodexConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub(crate) struct CodexToolCallReplyParam {
|
||||
pub session_id: Uuid,
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct CodexEvent {
|
||||
pub msg: AcpNotification,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AcpNotification {
|
||||
SessionConfigured(SessionConfiguredEvent),
|
||||
AgentMessage(AgentMessageEvent),
|
||||
AgentReasoning(AgentReasoningEvent),
|
||||
McpToolCallBegin(McpToolCallBeginEvent),
|
||||
McpToolCallEnd(McpToolCallEndEvent),
|
||||
ExecCommandBegin(ExecCommandBeginEvent),
|
||||
ExecCommandEnd(ExecCommandEndEvent),
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentMessageEvent {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct AgentReasoningEvent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolCallBeginEvent {
|
||||
pub call_id: String,
|
||||
pub server: String,
|
||||
pub tool: String,
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolCallEndEvent {
|
||||
pub call_id: String,
|
||||
pub result: Result<serde_json::Value, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecCommandBeginEvent {
|
||||
pub call_id: String,
|
||||
pub command: Vec<String>,
|
||||
pub cwd: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExecCommandEndEvent {
|
||||
pub call_id: String,
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub exit_code: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
|
||||
pub struct SessionConfiguredEvent {
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
fn md_codeblock(lang: &str, content: &str) -> String {
|
||||
if content.ends_with('\n') {
|
||||
format!("```{}\n{}```", lang, content)
|
||||
} else {
|
||||
format!("```{}\n{}\n```", lang, content)
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_bash_lc_and_escape(command: &[String]) -> String {
|
||||
match command {
|
||||
// exactly three items
|
||||
[first, second, third]
|
||||
// first two must be "bash", "-lc"
|
||||
if first == "bash" && second == "-lc" =>
|
||||
{
|
||||
third.clone()
|
||||
}
|
||||
_ => escape_command(command),
|
||||
}
|
||||
}
|
||||
|
||||
fn escape_command(command: &[String]) -> String {
|
||||
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
|
||||
}
|
||||
|
||||
fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
|
||||
let mut content = String::new();
|
||||
|
||||
for chunk in chunks {
|
||||
match chunk {
|
||||
ToolResponseContent::Text { text } => content.push_str(&text),
|
||||
ToolResponseContent::Image { .. } => {
|
||||
// todo!
|
||||
}
|
||||
ToolResponseContent::Audio { .. } => {
|
||||
// todo!
|
||||
}
|
||||
ToolResponseContent::Resource { .. } => {
|
||||
// todo!
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !content.is_empty() {
|
||||
Some(acp::ToolCallContent::Markdown { markdown: content })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodexApproval;
|
||||
impl context_server::types::Request for CodexApproval {
|
||||
type Params = CodexElicitation;
|
||||
type Response = CodexApprovalResponse;
|
||||
const METHOD: &'static str = "elicitation/create";
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ExecApprovalRequest {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
pub message: String,
|
||||
// #[serde(rename = "requestedSchema")]
|
||||
// pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
|
||||
// // These are additional fields the client can use to
|
||||
// // correlate the request with the codex tool call.
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
// pub codex_event_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PatchApprovalRequest {
|
||||
pub message: String,
|
||||
// #[serde(rename = "requestedSchema")]
|
||||
// pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_grant_root: Option<PathBuf>,
|
||||
pub codex_changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "codex_elicitation", rename_all = "kebab-case")]
|
||||
pub enum CodexElicitation {
|
||||
ExecApproval(ExecApprovalRequest),
|
||||
PatchApproval(PatchApprovalRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum FileChange {
|
||||
Add {
|
||||
content: String,
|
||||
},
|
||||
Delete,
|
||||
Update {
|
||||
unified_diff: String,
|
||||
move_path: Option<PathBuf>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CodexApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
/// User's decision in response to an ExecApprovalRequest.
|
||||
#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ReviewDecision {
|
||||
/// User has approved this command and the agent should execute it.
|
||||
Approved,
|
||||
|
||||
/// User has approved this command and wants to automatically approve any
|
||||
/// future identical instances (`command` and `cwd` match exactly) for the
|
||||
/// remainder of the session.
|
||||
ApprovedForSession,
|
||||
|
||||
/// User has denied this command and the agent should not execute it, but
|
||||
/// it should continue the session and try something else.
|
||||
#[default]
|
||||
Denied,
|
||||
|
||||
/// User has denied this command and the agent should not do anything until
|
||||
/// the user's next command.
|
||||
Abort,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
crate::common_e2e_tests!(Codex);
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path =
|
||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../codex/code-rs/target/debug/codex");
|
||||
|
||||
AgentServerCommand {
|
||||
path: cli_path,
|
||||
args: vec!["mcp".into()],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -350,6 +350,9 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
|||
claude: Some(AgentServerSettings {
|
||||
command: crate::claude::tests::local_command(),
|
||||
}),
|
||||
codex: Some(AgentServerSettings {
|
||||
command: crate::codex::tests::local_command(),
|
||||
}),
|
||||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
|
|
|
@ -1,29 +1,23 @@
|
|||
use std::{cell::RefCell, rc::Rc};
|
||||
use std::{cell::RefCell, path::PathBuf, rc::Rc};
|
||||
|
||||
use acp_thread::AcpClientDelegate;
|
||||
use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams};
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::{
|
||||
listener::McpServer,
|
||||
types::{
|
||||
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
|
||||
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
|
||||
ToolResponseContent, ToolsCapabilities, requests,
|
||||
},
|
||||
use context_server::types::{
|
||||
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
|
||||
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
|
||||
ToolResponseContent, ToolsCapabilities, requests,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::debug_panic;
|
||||
|
||||
use crate::claude::{
|
||||
McpServerConfig,
|
||||
tools::{ClaudeTool, EditToolParams, ReadToolParams},
|
||||
};
|
||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
|
||||
pub struct ClaudeMcpServer {
|
||||
server: McpServer,
|
||||
pub struct ZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
|
@ -52,15 +46,36 @@ enum PermissionToolBehavior {
|
|||
Deny,
|
||||
}
|
||||
|
||||
impl ClaudeMcpServer {
|
||||
#[derive(Clone)]
|
||||
pub struct EnabledTools {
|
||||
pub read: bool,
|
||||
pub edit: bool,
|
||||
pub permission: bool,
|
||||
}
|
||||
|
||||
impl Default for EnabledTools {
|
||||
fn default() -> Self {
|
||||
EnabledTools {
|
||||
read: true,
|
||||
edit: true,
|
||||
permission: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZedMcpServer {
|
||||
pub async fn new(
|
||||
delegate: watch::Receiver<Option<AcpClientDelegate>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
enabled_tools: EnabledTools,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = McpServer::new(cx).await?;
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
|
||||
mcp_server.handle_request::<requests::ListTools>({
|
||||
let enabled_tools = enabled_tools.clone();
|
||||
move |_, cx| Self::handle_list_tools(enabled_tools.clone(), cx)
|
||||
});
|
||||
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
|
||||
Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx)
|
||||
});
|
||||
|
@ -70,9 +85,7 @@ impl ClaudeMcpServer {
|
|||
|
||||
pub fn server_config(&self) -> Result<McpServerConfig> {
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("finding current executable path for use in mcp_server")?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
.context("finding current executable path for use in mcp_server")?;
|
||||
|
||||
Ok(McpServerConfig {
|
||||
command: zed_path,
|
||||
|
@ -107,17 +120,17 @@ impl ClaudeMcpServer {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
|
||||
fn handle_list_tools(enabled: EnabledTools, cx: &App) -> Task<Result<ListToolsResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(ListToolsResponse {
|
||||
tools: vec![
|
||||
Tool {
|
||||
tools: [
|
||||
enabled.permission.then(|| Tool {
|
||||
name: PERMISSION_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(PermissionToolParams).into(),
|
||||
description: None,
|
||||
annotations: None,
|
||||
},
|
||||
Tool {
|
||||
}),
|
||||
enabled.read.then(|| Tool {
|
||||
name: READ_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(ReadToolParams).into(),
|
||||
description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
|
||||
|
@ -130,8 +143,8 @@ impl ClaudeMcpServer {
|
|||
// true or false seem too strong, let's try a none.
|
||||
idempotent_hint: None,
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
}),
|
||||
enabled.edit.then(|| Tool {
|
||||
name: EDIT_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(EditToolParams).into(),
|
||||
description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
|
||||
|
@ -142,8 +155,8 @@ impl ClaudeMcpServer {
|
|||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}),
|
||||
},
|
||||
],
|
||||
}),
|
||||
].into_iter().flatten().collect(),
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
})
|
||||
|
@ -294,3 +307,18 @@ impl ClaudeMcpServer {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpConfig {
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpServerConfig {
|
||||
pub command: PathBuf,
|
||||
pub args: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
|
@ -13,6 +13,7 @@ pub fn init(cx: &mut App) {
|
|||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
pub codex: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
|
@ -29,13 +30,21 @@ impl settings::Settings for AllAgentServersSettings {
|
|||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
codex,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
if codex.is_some() {
|
||||
settings.codex = codex.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
|
|
@ -1126,21 +1126,27 @@ impl AcpThreadView {
|
|||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Execute {
|
||||
command,
|
||||
command: _,
|
||||
root_command,
|
||||
description,
|
||||
} => confirmation_container
|
||||
.child(v_flex().px_2().pb_1p5().child(command.clone()).children(
|
||||
description.clone().map(|description| {
|
||||
self.render_markdown(description, default_markdown_style(false, window, cx))
|
||||
.child(
|
||||
v_flex()
|
||||
.px_2()
|
||||
.pb_1p5()
|
||||
.children(description.clone().map(|description| {
|
||||
self.render_markdown(
|
||||
description,
|
||||
default_markdown_style(false, window, cx),
|
||||
)
|
||||
.on_url_click({
|
||||
let workspace = self.workspace.clone();
|
||||
move |text, window, cx| {
|
||||
Self::open_link(text, &workspace, window, cx);
|
||||
}
|
||||
})
|
||||
}),
|
||||
))
|
||||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
|
|
|
@ -1980,6 +1980,13 @@ impl AgentPanel {
|
|||
);
|
||||
}),
|
||||
)
|
||||
.action(
|
||||
"New Codex Thread",
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Codex),
|
||||
}
|
||||
.boxed_clone(),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
|
|
|
@ -150,6 +150,7 @@ enum ExternalAgent {
|
|||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
Codex,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
|
@ -157,6 +158,7 @@ impl ExternalAgent {
|
|||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, select};
|
||||
use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::barrier;
|
||||
|
@ -10,15 +10,19 @@ use smol::channel;
|
|||
use std::{
|
||||
fmt,
|
||||
path::PathBuf,
|
||||
pin::pin,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicI32, Ordering::SeqCst},
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::TryFutureExt;
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
use crate::transport::{StdioTransport, Transport};
|
||||
use crate::{
|
||||
transport::{StdioTransport, Transport},
|
||||
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
|
||||
};
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
|
|||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
|
||||
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
|
@ -46,6 +51,7 @@ pub(crate) struct Client {
|
|||
outbound_tx: channel::Sender<String>,
|
||||
name: Arc<str>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
#[allow(dead_code)]
|
||||
|
@ -78,6 +84,15 @@ pub struct Request<'a, T> {
|
|||
pub params: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AnyRequest<'a> {
|
||||
pub jsonrpc: &'a str,
|
||||
pub id: RequestId,
|
||||
pub method: &'a str,
|
||||
#[serde(skip_serializing_if = "is_null_value")]
|
||||
pub params: Option<&'a RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AnyResponse<'a> {
|
||||
jsonrpc: &'a str,
|
||||
|
@ -176,15 +191,23 @@ impl Client {
|
|||
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
|
||||
let response_handlers =
|
||||
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
|
||||
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
|
||||
|
||||
let receive_input_task = cx.spawn({
|
||||
let notification_handlers = notification_handlers.clone();
|
||||
let response_handlers = response_handlers.clone();
|
||||
let request_handlers = request_handlers.clone();
|
||||
let transport = transport.clone();
|
||||
async move |cx| {
|
||||
Self::handle_input(transport, notification_handlers, response_handlers, cx)
|
||||
.log_err()
|
||||
.await
|
||||
Self::handle_input(
|
||||
transport,
|
||||
notification_handlers,
|
||||
request_handlers,
|
||||
response_handlers,
|
||||
cx,
|
||||
)
|
||||
.log_err()
|
||||
.await
|
||||
}
|
||||
});
|
||||
let receive_err_task = cx.spawn({
|
||||
|
@ -211,6 +234,7 @@ impl Client {
|
|||
server_id,
|
||||
notification_handlers,
|
||||
response_handlers,
|
||||
request_handlers,
|
||||
name: server_name,
|
||||
next_id: Default::default(),
|
||||
outbound_tx,
|
||||
|
@ -230,13 +254,24 @@ impl Client {
|
|||
async fn handle_input(
|
||||
transport: Arc<dyn Transport>,
|
||||
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
|
||||
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
|
||||
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut receiver = transport.receive();
|
||||
|
||||
while let Some(message) = receiver.next().await {
|
||||
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
log::trace!("recv: {}", &message);
|
||||
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
|
||||
let mut request_handlers = request_handlers.lock();
|
||||
if let Some(handler) = request_handlers.get_mut(request.method) {
|
||||
handler(
|
||||
request.id,
|
||||
request.params.unwrap_or(RawValue::NULL),
|
||||
cx.clone(),
|
||||
);
|
||||
}
|
||||
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
|
||||
if let Some(handlers) = response_handlers.lock().as_mut() {
|
||||
if let Some(handler) = handlers.remove(&response.id) {
|
||||
handler(Ok(message.to_string()));
|
||||
|
@ -247,6 +282,8 @@ impl Client {
|
|||
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
}
|
||||
} else {
|
||||
log::error!("Unhandled JSON from context_server: {}", message);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -294,6 +331,17 @@ impl Client {
|
|||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
) -> Result<T> {
|
||||
self.request_with(method, params, None, Some(REQUEST_TIMEOUT))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn request_with<T: DeserializeOwned>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: impl Serialize,
|
||||
cancel_rx: Option<oneshot::Receiver<()>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<T> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let request = serde_json::to_string(&Request {
|
||||
|
@ -329,7 +377,25 @@ impl Client {
|
|||
handle_response?;
|
||||
send?;
|
||||
|
||||
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
|
||||
let mut timeout_fut = pin!(
|
||||
if let Some(timeout) = timeout {
|
||||
future::Either::Left(executor.timer(timeout))
|
||||
} else {
|
||||
future::Either::Right(future::pending())
|
||||
}
|
||||
.fuse()
|
||||
);
|
||||
|
||||
let mut cancel_fut = pin!(
|
||||
match cancel_rx {
|
||||
Some(rx) => future::Either::Left(async {
|
||||
rx.await.log_err();
|
||||
}),
|
||||
None => future::Either::Right(future::pending()),
|
||||
}
|
||||
.fuse()
|
||||
);
|
||||
|
||||
select! {
|
||||
response = rx.fuse() => {
|
||||
let elapsed = started.elapsed();
|
||||
|
@ -348,8 +414,18 @@ impl Client {
|
|||
Err(_) => anyhow::bail!("cancelled")
|
||||
}
|
||||
}
|
||||
_ = timeout => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
|
||||
_ = cancel_fut => {
|
||||
self.notify(
|
||||
Cancelled::METHOD,
|
||||
ClientNotification::Cancelled(CancelledParams {
|
||||
request_id: RequestId::Int(id),
|
||||
reason: None
|
||||
})
|
||||
).log_err();
|
||||
anyhow::bail!("Request cancelled")
|
||||
}
|
||||
_ = timeout_fut => {
|
||||
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout);
|
||||
anyhow::bail!("Context server request timeout");
|
||||
}
|
||||
}
|
||||
|
@ -377,6 +453,79 @@ impl Client {
|
|||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
}
|
||||
|
||||
pub fn on_request<R: crate::types::Request, F>(&self, mut f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
|
||||
{
|
||||
let outbound_tx = self.outbound_tx.clone();
|
||||
self.request_handlers.lock().insert(
|
||||
R::METHOD,
|
||||
Box::new(move |id, json, cx| {
|
||||
let outbound_tx = outbound_tx.clone();
|
||||
match serde_json::from_str(json.get()) {
|
||||
Ok(req) => {
|
||||
let task = f(req, cx.clone());
|
||||
cx.foreground_executor()
|
||||
.spawn(async move {
|
||||
match task.await {
|
||||
Ok(res) => {
|
||||
outbound_tx
|
||||
.send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id,
|
||||
value: CspResult::Ok(Some(res)),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
Err(e) => {
|
||||
outbound_tx
|
||||
.send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id,
|
||||
value: CspResult::<()>::Error(Some(Error {
|
||||
code: -1, // todo!()
|
||||
message: format!("{e}"),
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
Err(e) => {
|
||||
cx.foreground_executor()
|
||||
.spawn(async move {
|
||||
outbound_tx
|
||||
.send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id,
|
||||
value: CspResult::<()>::Error(Some(Error {
|
||||
code: -1, // todo!()
|
||||
message: format!("{e}"),
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
|
|
|
@ -5,7 +5,12 @@
|
|||
//! read/write messages and the types from types.rs for serialization/deserialization
|
||||
//! of messages.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AsyncApp, Task};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::client::Client;
|
||||
use crate::types::{self, Notification, Request};
|
||||
|
@ -95,7 +100,32 @@ impl InitializedContextServerProtocol {
|
|||
self.inner.request(T::METHOD, params).await
|
||||
}
|
||||
|
||||
pub async fn request_with<T: Request>(
|
||||
&self,
|
||||
params: T::Params,
|
||||
cancel_rx: Option<oneshot::Receiver<()>>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<T::Response> {
|
||||
self.inner
|
||||
.request_with(T::METHOD, params, cancel_rx, timeout)
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
|
||||
pub fn on_request<R: crate::types::Request, F>(&self, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
|
||||
{
|
||||
self.inner.on_request::<R, F>(f);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::client::RequestId;
|
||||
|
||||
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
|
||||
pub const VERSION_2024_11_05: &str = "2024-11-05";
|
||||
|
||||
|
@ -100,6 +102,7 @@ pub mod notifications {
|
|||
notification!("notifications/initialized", Initialized, ());
|
||||
notification!("notifications/progress", Progress, ProgressParams);
|
||||
notification!("notifications/message", Message, MessageParams);
|
||||
notification!("notifications/cancelled", Cancelled, CancelledParams);
|
||||
notification!(
|
||||
"notifications/resources/updated",
|
||||
ResourcesUpdated,
|
||||
|
@ -617,11 +620,14 @@ pub enum ClientNotification {
|
|||
Initialized,
|
||||
Progress(ProgressParams),
|
||||
RootsListChanged,
|
||||
Cancelled {
|
||||
request_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reason: Option<String>,
|
||||
},
|
||||
Cancelled(CancelledParams),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CancelledParams {
|
||||
pub request_id: RequestId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
|
@ -160,6 +160,7 @@ zed_actions.workspace = true
|
|||
zeta.workspace = true
|
||||
zlog.workspace = true
|
||||
zlog_settings.workspace = true
|
||||
erased-serde = "0.4.6"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
windows.workspace = true
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue