Refactor to use new ACP crate (#35043)

This will prepare us for running the protocol over MCP

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Agus Zubiaga 2025-07-24 14:39:29 -03:00 committed by GitHub
parent 45ddf32a1d
commit 2d0f10c48a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1830 additions and 1748 deletions

15
Cargo.lock generated
View file

@ -6,6 +6,7 @@ version = 4
name = "acp_thread" name = "acp_thread"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"agent-client-protocol",
"agentic-coding-protocol", "agentic-coding-protocol",
"anyhow", "anyhow",
"assistant_tool", "assistant_tool",
@ -135,11 +136,23 @@ dependencies = [
"zstd", "zstd",
] ]
[[package]]
name = "agent-client-protocol"
version = "0.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fb7f39671e02f8a1aeb625652feae40b6fc2597baaa97e028a98863477aecbd"
dependencies = [
"schemars",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "agent_servers" name = "agent_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"agent-client-protocol",
"agentic-coding-protocol", "agentic-coding-protocol",
"anyhow", "anyhow",
"collections", "collections",
@ -195,9 +208,9 @@ version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"agent", "agent",
"agent-client-protocol",
"agent_servers", "agent_servers",
"agent_settings", "agent_settings",
"agentic-coding-protocol",
"ai_onboarding", "ai_onboarding",
"anyhow", "anyhow",
"assistant_context", "assistant_context",

View file

@ -413,6 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# #
agentic-coding-protocol = "0.0.10" agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.10"
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"

View file

@ -16,6 +16,7 @@ doctest = false
test-support = ["gpui/test-support", "project/test-support"] test-support = ["gpui/test-support", "project/test-support"]
[dependencies] [dependencies]
agent-client-protocol.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true

File diff suppressed because it is too large Load diff

View file

@ -1,20 +1,26 @@
use agentic_coding_protocol as acp; use std::{path::Path, rc::Rc};
use agent_client_protocol as acp;
use anyhow::Result; use anyhow::Result;
use futures::future::{FutureExt as _, LocalBoxFuture}; use gpui::{AsyncApp, Entity, Task};
use project::Project;
use ui::App;
use crate::AcpThread;
pub trait AgentConnection { pub trait AgentConnection {
fn request_any( fn name(&self) -> &'static str;
&self,
params: acp::AnyAgentRequest,
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>>;
}
impl AgentConnection for acp::AgentConnection { fn new_thread(
fn request_any( self: Rc<Self>,
&self, project: Entity<Project>,
params: acp::AnyAgentRequest, cwd: &Path,
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> { cx: &mut AsyncApp,
let task = self.request_any(params); ) -> Task<Result<Entity<AcpThread>>>;
async move { Ok(task.await?) }.boxed_local()
} fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
} }

View file

@ -0,0 +1,461 @@
// Translates old acp agents into the new schema
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
use ui::App;
use crate::{AcpThread, AcpThreadEvent, AgentConnection, ToolCallContent, ToolCallStatus};
#[derive(Clone)]
pub struct OldAcpClientDelegate {
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
cx: AsyncApp,
next_tool_call_id: Rc<RefCell<u64>>,
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
impl OldAcpClientDelegate {
pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
Self {
thread,
cx,
next_tool_call_id: Rc::new(RefCell::new(0)),
}
}
}
impl acp_old::Client for OldAcpClientDelegate {
async fn stream_assistant_message_chunk(
&self,
params: acp_old::StreamAssistantMessageChunkParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| match params.chunk {
acp_old::AssistantMessageChunk::Text { text } => {
thread.push_assistant_chunk(text.into(), false, cx)
}
acp_old::AssistantMessageChunk::Thought { thought } => {
thread.push_assistant_chunk(thought.into(), true, cx)
}
})
.ok();
})?;
Ok(())
}
async fn request_tool_call_confirmation(
&self,
request: acp_old::RequestToolCallConfirmationParams,
) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
let tool_call = into_new_tool_call(
acp::ToolCallId(old_acp_id.to_string().into()),
request.tool_call,
);
let mut options = match request.confirmation {
acp_old::ToolCallConfirmation::Edit { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow Edits".to_string(),
)],
acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", root_command),
)],
acp_old::ToolCallConfirmation::Mcp {
server_name,
tool_name,
..
} => vec![
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", server_name),
),
(
acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
acp::PermissionOptionKind::AllowAlways,
format!("Always Allow {}", tool_name),
),
],
acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
acp_old::ToolCallConfirmation::Other { .. } => vec![(
acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
acp::PermissionOptionKind::AllowAlways,
"Always Allow".to_string(),
)],
};
options.extend([
(
acp_old::ToolCallConfirmationOutcome::Allow,
acp::PermissionOptionKind::AllowOnce,
"Allow".to_string(),
),
(
acp_old::ToolCallConfirmationOutcome::Reject,
acp::PermissionOptionKind::RejectOnce,
"Reject".to_string(),
),
]);
let mut outcomes = Vec::with_capacity(options.len());
let mut acp_options = Vec::with_capacity(options.len());
for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
outcomes.push(outcome);
acp_options.push(acp::PermissionOption {
id: acp::PermissionOptionId(index.to_string().into()),
label,
kind,
})
}
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_permission(tool_call, acp_options, cx)
})
})?
.context("Failed to update thread")?
.await;
let outcome = match response {
Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)],
Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
};
Ok(acp_old::RequestToolCallConfirmationResponse {
id: acp_old::ToolCallId(old_acp_id),
outcome: outcome,
})
}
async fn push_tool_call(
&self,
request: acp_old::PushToolCallParams,
) -> Result<acp_old::PushToolCallResponse, acp_old::Error> {
let cx = &mut self.cx.clone();
let old_acp_id = *self.next_tool_call_id.borrow() + 1;
self.next_tool_call_id.replace(old_acp_id);
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.upsert_tool_call(
into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request),
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(acp_old::PushToolCallResponse {
id: acp_old::ToolCallId(old_acp_id),
})
}
async fn update_tool_call(
&self,
request: acp_old::UpdateToolCallParams,
) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
let languages = thread.project.read(cx).languages().clone();
if let Some((ix, tool_call)) = thread
.tool_call_mut(&acp::ToolCallId(request.tool_call_id.0.to_string().into()))
{
tool_call.status = ToolCallStatus::Allowed {
status: into_new_tool_call_status(request.status),
};
tool_call.content = request
.content
.into_iter()
.map(|content| {
ToolCallContent::from_acp(
into_new_tool_call_content(content),
languages.clone(),
cx,
)
})
.collect();
cx.emit(AcpThreadEvent::EntryUpdated(ix));
anyhow::Ok(())
} else {
anyhow::bail!("Tool call not found")
}
})
})?
.context("Failed to update thread")??;
Ok(())
}
async fn update_plan(&self, request: acp_old::UpdatePlanParams) -> Result<(), acp_old::Error> {
let cx = &mut self.cx.clone();
cx.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.update_plan(
acp::Plan {
entries: request
.entries
.into_iter()
.map(into_new_plan_entry)
.collect(),
},
cx,
)
})
})?
.context("Failed to update thread")?;
Ok(())
}
async fn read_text_file(
&self,
acp_old::ReadTextFileParams { path, line, limit }: acp_old::ReadTextFileParams,
) -> Result<acp_old::ReadTextFileResponse, acp_old::Error> {
let content = self
.cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.read_text_file(path, line, limit, false, cx)
})
})?
.context("Failed to update thread")?
.await?;
Ok(acp_old::ReadTextFileResponse { content })
}
async fn write_text_file(
&self,
acp_old::WriteTextFileParams { path, content }: acp_old::WriteTextFileParams,
) -> Result<(), acp_old::Error> {
self.cx
.update(|cx| {
self.thread
.borrow()
.update(cx, |thread, cx| thread.write_text_file(path, content, cx))
})?
.context("Failed to update thread")?
.await?;
Ok(())
}
}
fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) -> acp::ToolCall {
acp::ToolCall {
id: id,
label: request.label,
kind: acp_kind_from_old_icon(request.icon),
status: acp::ToolCallStatus::InProgress,
content: request
.content
.into_iter()
.map(into_new_tool_call_content)
.collect(),
locations: request
.locations
.into_iter()
.map(into_new_tool_call_location)
.collect(),
}
}
fn acp_kind_from_old_icon(icon: acp_old::Icon) -> acp::ToolKind {
match icon {
acp_old::Icon::FileSearch => acp::ToolKind::Search,
acp_old::Icon::Folder => acp::ToolKind::Search,
acp_old::Icon::Globe => acp::ToolKind::Search,
acp_old::Icon::Hammer => acp::ToolKind::Other,
acp_old::Icon::LightBulb => acp::ToolKind::Think,
acp_old::Icon::Pencil => acp::ToolKind::Edit,
acp_old::Icon::Regex => acp::ToolKind::Search,
acp_old::Icon::Terminal => acp::ToolKind::Execute,
}
}
fn into_new_tool_call_status(status: acp_old::ToolCallStatus) -> acp::ToolCallStatus {
match status {
acp_old::ToolCallStatus::Running => acp::ToolCallStatus::InProgress,
acp_old::ToolCallStatus::Finished => acp::ToolCallStatus::Completed,
acp_old::ToolCallStatus::Error => acp::ToolCallStatus::Failed,
}
}
fn into_new_tool_call_content(content: acp_old::ToolCallContent) -> acp::ToolCallContent {
match content {
acp_old::ToolCallContent::Markdown { markdown } => acp::ToolCallContent::ContentBlock {
content: acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: markdown,
}),
},
acp_old::ToolCallContent::Diff { diff } => acp::ToolCallContent::Diff {
diff: into_new_diff(diff),
},
}
}
fn into_new_diff(diff: acp_old::Diff) -> acp::Diff {
acp::Diff {
path: diff.path,
old_text: diff.old_text,
new_text: diff.new_text,
}
}
fn into_new_tool_call_location(location: acp_old::ToolCallLocation) -> acp::ToolCallLocation {
acp::ToolCallLocation {
path: location.path,
line: location.line,
}
}
fn into_new_plan_entry(entry: acp_old::PlanEntry) -> acp::PlanEntry {
acp::PlanEntry {
content: entry.content,
priority: into_new_plan_priority(entry.priority),
status: into_new_plan_status(entry.status),
}
}
fn into_new_plan_priority(priority: acp_old::PlanEntryPriority) -> acp::PlanEntryPriority {
match priority {
acp_old::PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
acp_old::PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
acp_old::PlanEntryPriority::High => acp::PlanEntryPriority::High,
}
}
fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatus {
match status {
acp_old::PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
acp_old::PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
acp_old::PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
}
}
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
}
impl AgentConnection for OldAcpAgentConnection {
fn name(&self) -> &'static str {
self.name
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {
protocol_version: acp_old::ProtocolVersion::latest(),
}
.into_any(),
);
cx.spawn(async move |cx| {
let result = task.await?;
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(Unauthenticated)
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.clone(), project, session_id, cx)
});
thread
})
})
}
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
Ok(())
})
}
fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>> {
let chunks = params
.prompt
.into_iter()
.filter_map(|block| match block {
acp::ContentBlock::Text(text) => {
Some(acp_old::UserMessageChunk::Text { text: text.text })
}
acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
path: link.uri.into(),
}),
_ => None,
})
.collect();
let task = self
.connection
.request_any(acp_old::SendUserMessageParams { chunks }.into_any());
cx.foreground_executor().spawn(async move {
task.await?;
anyhow::Ok(())
})
}
fn cancel(&self, _session_id: &acp::SessionId, cx: &mut App) {
let task = self
.connection
.request_any(acp_old::CancelSendMessageParams.into_any());
cx.foreground_executor()
.spawn(async move {
task.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx)
}
}

View file

@ -18,6 +18,7 @@ doctest = false
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true collections.workspace = true

View file

@ -1,7 +1,6 @@
mod claude; mod claude;
mod gemini; mod gemini;
mod settings; mod settings;
mod stdio_agent_server;
#[cfg(test)] #[cfg(test)]
mod e2e_tests; mod e2e_tests;
@ -9,9 +8,8 @@ mod e2e_tests;
pub use claude::*; pub use claude::*;
pub use gemini::*; pub use gemini::*;
pub use settings::*; pub use settings::*;
pub use stdio_agent_server::*;
use acp_thread::AcpThread; use acp_thread::AgentConnection;
use anyhow::Result; use anyhow::Result;
use collections::HashMap; use collections::HashMap;
use gpui::{App, AsyncApp, Entity, SharedString, Task}; use gpui::{App, AsyncApp, Entity, SharedString, Task};
@ -20,6 +18,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
rc::Rc,
sync::Arc, sync::Arc,
}; };
use util::ResultExt as _; use util::ResultExt as _;
@ -33,14 +32,14 @@ pub trait AgentServer: Send {
fn name(&self) -> &'static str; fn name(&self) -> &'static str;
fn empty_state_headline(&self) -> &'static str; fn empty_state_headline(&self) -> &'static str;
fn empty_state_message(&self) -> &'static str; fn empty_state_message(&self) -> &'static str;
fn supports_always_allow(&self) -> bool;
fn new_thread( fn connect(
&self, &self,
// these will go away when old_acp is fully removed
root_dir: &Path, root_dir: &Path,
project: &Entity<Project>, project: &Entity<Project>,
cx: &mut App, cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>; ) -> Task<Result<Rc<dyn AgentConnection>>>;
} }
impl std::fmt::Debug for AgentServerCommand { impl std::fmt::Debug for AgentServerCommand {

View file

@ -1,5 +1,5 @@
mod mcp_server; mod mcp_server;
mod tools; pub mod tools;
use collections::HashMap; use collections::HashMap;
use project::Project; use project::Project;
@ -12,28 +12,24 @@ use std::pin::pin;
use std::rc::Rc; use std::rc::Rc;
use uuid::Uuid; use uuid::Uuid;
use agentic_coding_protocol::{ use agent_client_protocol as acp;
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::future::LocalBoxFuture; use futures::{AsyncBufReadExt, AsyncWriteExt};
use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt};
use futures::{ use futures::{
AsyncRead, AsyncWrite, FutureExt, StreamExt, AsyncRead, AsyncWrite, FutureExt, StreamExt,
channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
io::BufReader, io::BufReader,
select_biased, select_biased,
}; };
use gpui::{App, AppContext, Entity, Task}; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::ResultExt; use util::ResultExt;
use crate::claude::mcp_server::ClaudeMcpServer; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool; use crate::claude::tools::ClaudeTool;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; use acp_thread::{AcpThread, AgentConnection};
#[derive(Clone)] #[derive(Clone)]
pub struct ClaudeCode; pub struct ClaudeCode;
@ -55,29 +51,57 @@ impl AgentServer for ClaudeCode {
ui::IconName::AiClaude ui::IconName::AiClaude
} }
fn supports_always_allow(&self) -> bool { fn connect(
false &self,
_root_dir: &Path,
_project: &Entity<Project>,
_cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let connection = ClaudeAgentConnection {
sessions: Default::default(),
};
Task::ready(Ok(Rc::new(connection) as _))
}
}
#[cfg(unix)]
fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
let pid = nix::unistd::Pid::from_raw(pid);
nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
.map_err(|e| anyhow!("Failed to interrupt process: {}", e))
}
#[cfg(windows)]
fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
panic!("Cancel not implemented on Windows")
}
struct ClaudeAgentConnection {
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
impl AgentConnection for ClaudeAgentConnection {
fn name(&self) -> &'static str {
ClaudeCode.name()
} }
fn new_thread( fn new_thread(
&self, self: Rc<Self>,
root_dir: &Path, project: Entity<Project>,
project: &Entity<Project>, cwd: &Path,
cx: &mut App, cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> { ) -> Task<Result<Entity<AcpThread>>> {
let project = project.clone(); let cwd = cwd.to_owned();
let root_dir = root_dir.to_path_buf();
let title = self.name().into();
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let (mut delegate_tx, delegate_rx) = watch::channel(None); let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let tool_id_map = Rc::new(RefCell::new(HashMap::default())); let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
let mut mcp_servers = HashMap::default(); let mut mcp_servers = HashMap::default();
mcp_servers.insert( mcp_servers.insert(
mcp_server::SERVER_NAME.to_string(), mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?, permission_mcp_server.server_config()?,
); );
let mcp_config = McpConfig { mcp_servers }; let mcp_config = McpConfig { mcp_servers };
@ -104,177 +128,180 @@ impl AgentServer for ClaudeCode {
let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>(); let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>();
let session_id = Uuid::new_v4(); let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
log::trace!("Starting session with id: {}", session_id); log::trace!("Starting session with id: {}", session_id);
cx.background_spawn(async move { cx.background_spawn({
let mut outgoing_rx = Some(outgoing_rx); let session_id = session_id.clone();
let mut mode = ClaudeSessionMode::Start; async move {
let mut outgoing_rx = Some(outgoing_rx);
let mut mode = ClaudeSessionMode::Start;
loop { loop {
let mut child = let mut child = spawn_claude(
spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir) &command,
.await?; mode,
mode = ClaudeSessionMode::Resume; session_id.clone(),
&mcp_config_path,
let pid = child.id(); &cwd,
log::trace!("Spawned (pid: {})", pid);
let mut io_fut = pin!(
ClaudeAgentConnection::handle_io(
outgoing_rx.take().unwrap(),
incoming_message_tx.clone(),
child.stdin.take().unwrap(),
child.stdout.take().unwrap(),
) )
.fuse() .await?;
); mode = ClaudeSessionMode::Resume;
select_biased! { let pid = child.id();
done_tx = cancel_rx.next() => { log::trace!("Spawned (pid: {})", pid);
if let Some(done_tx) = done_tx {
log::trace!("Interrupted (pid: {})", pid); let mut io_fut = pin!(
let result = send_interrupt(pid as i32); ClaudeAgentSession::handle_io(
outgoing_rx.replace(io_fut.await?); outgoing_rx.take().unwrap(),
done_tx.send(result).log_err(); incoming_message_tx.clone(),
continue; child.stdin.take().unwrap(),
child.stdout.take().unwrap(),
)
.fuse()
);
select_biased! {
done_tx = cancel_rx.next() => {
if let Some(done_tx) = done_tx {
log::trace!("Interrupted (pid: {})", pid);
let result = send_interrupt(pid as i32);
outgoing_rx.replace(io_fut.await?);
done_tx.send(result).log_err();
continue;
}
}
result = io_fut => {
result?;
} }
} }
result = io_fut => {
result?; log::trace!("Stopped (pid: {})", pid);
} break;
} }
log::trace!("Stopped (pid: {})", pid); drop(mcp_config_path);
break; anyhow::Ok(())
} }
drop(mcp_config_path);
anyhow::Ok(())
}) })
.detach(); .detach();
cx.new(|cx| { let end_turn_tx = Rc::new(RefCell::new(None));
let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); let end_turn_tx = end_turn_tx.clone();
delegate_tx.send(Some(delegate.clone())).log_err(); let thread_rx = thread_rx.clone();
async move |cx| {
let handler_task = cx.foreground_executor().spawn({ while let Some(message) = incoming_message_rx.next().await {
let end_turn_tx = end_turn_tx.clone(); ClaudeAgentSession::handle_message(
let tool_id_map = tool_id_map.clone(); thread_rx.clone(),
let delegate = delegate.clone(); message,
async move { end_turn_tx.clone(),
while let Some(message) = incoming_message_rx.next().await { cx,
ClaudeAgentConnection::handle_message( )
delegate.clone(), .await
message,
end_turn_tx.clone(),
tool_id_map.clone(),
)
.await
}
} }
}); }
});
let mut connection = ClaudeAgentConnection { let thread =
delegate, cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
outgoing_tx,
end_turn_tx,
cancel_tx,
session_id,
_handler_task: handler_task,
_mcp_server: None,
};
connection._mcp_server = Some(mcp_server); thread_tx.send(thread.downgrade())?;
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
}) let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
cancel_tx,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
self.sessions.borrow_mut().insert(session_id, session);
Ok(thread)
}) })
} }
}
#[cfg(unix)] fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { Task::ready(Err(anyhow!("Authentication not supported")))
let pid = nix::unistd::Pid::from_raw(pid); }
nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>> {
.map_err(|e| anyhow!("Failed to interrupt process: {}", e)) let sessions = self.sessions.borrow();
} let Some(session) = sessions.get(&params.session_id) else {
return Task::ready(Err(anyhow!(
"Attempted to send message to nonexistent session {}",
params.session_id
)));
};
#[cfg(windows)] let (tx, rx) = oneshot::channel();
fn send_interrupt(_pid: i32) -> anyhow::Result<()> { session.end_turn_tx.borrow_mut().replace(tx);
panic!("Cancel not implemented on Windows")
}
impl AgentConnection for ClaudeAgentConnection { let mut content = String::new();
/// Send a request to the agent and wait for a response. for chunk in params.prompt {
fn request_any( match chunk {
&self, acp::ContentBlock::Text(text_content) => {
params: AnyAgentRequest, content.push_str(&text_content.text);
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
let delegate = self.delegate.clone();
let end_turn_tx = self.end_turn_tx.clone();
let outgoing_tx = self.outgoing_tx.clone();
let mut cancel_tx = self.cancel_tx.clone();
let session_id = self.session_id;
async move {
match params {
// todo: consider sending an empty request so we get the init response?
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
acp::InitializeResponse {
is_authenticated: true,
protocol_version: ProtocolVersion::latest(),
},
)),
AnyAgentRequest::AuthenticateParams(_) => {
Err(anyhow!("Authentication not supported"))
} }
AnyAgentRequest::SendUserMessageParams(message) => { acp::ContentBlock::ResourceLink(resource_link) => {
delegate.clear_completed_plan_entries().await?; content.push_str(&format!("@{}", resource_link.uri));
let (tx, rx) = oneshot::channel();
end_turn_tx.borrow_mut().replace(tx);
let mut content = String::new();
for chunk in message.chunks {
match chunk {
agentic_coding_protocol::UserMessageChunk::Text { text } => {
content.push_str(&text)
}
agentic_coding_protocol::UserMessageChunk::Path { path } => {
content.push_str(&format!("@{path:?}"))
}
}
}
outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
id: None,
model: None,
stop_reason: None,
stop_sequence: None,
usage: None,
},
session_id: Some(session_id),
})?;
rx.await??;
Ok(AnyAgentResult::SendUserMessageResponse(
acp::SendUserMessageResponse,
))
} }
AnyAgentRequest::CancelSendMessageParams(_) => { acp::ContentBlock::Audio(_)
let (done_tx, done_rx) = oneshot::channel(); | acp::ContentBlock::Image(_)
cancel_tx.send(done_tx).await?; | acp::ContentBlock::Resource(_) => {
done_rx.await??; // TODO
Ok(AnyAgentResult::CancelSendMessageResponse(
acp::CancelSendMessageResponse,
))
} }
} }
} }
.boxed_local()
if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
message: Message {
role: Role::User,
content: Content::UntaggedText(content),
id: None,
model: None,
stop_reason: None,
stop_sequence: None,
usage: None,
},
session_id: Some(params.session_id.to_string()),
}) {
return Task::ready(Err(anyhow!(err)));
}
cx.foreground_executor().spawn(async move {
rx.await??;
Ok(())
})
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&session_id) else {
log::warn!("Attempted to cancel nonexistent session {}", session_id);
return;
};
let (done_tx, done_rx) = oneshot::channel();
if session
.cancel_tx
.unbounded_send(done_tx)
.log_err()
.is_some()
{
let end_turn_tx = session.end_turn_tx.clone();
cx.foreground_executor()
.spawn(async move {
done_rx.await??;
if let Some(end_turn_tx) = end_turn_tx.take() {
end_turn_tx.send(Ok(())).ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
} }
} }
@ -287,7 +314,7 @@ enum ClaudeSessionMode {
async fn spawn_claude( async fn spawn_claude(
command: &AgentServerCommand, command: &AgentServerCommand,
mode: ClaudeSessionMode, mode: ClaudeSessionMode,
session_id: Uuid, session_id: acp::SessionId,
mcp_config_path: &Path, mcp_config_path: &Path,
root_dir: &Path, root_dir: &Path,
) -> Result<Child> { ) -> Result<Child> {
@ -327,88 +354,103 @@ async fn spawn_claude(
Ok(child) Ok(child)
} }
struct ClaudeAgentConnection { struct ClaudeAgentSession {
delegate: AcpClientDelegate,
session_id: Uuid,
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>, cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
_mcp_server: Option<ClaudeMcpServer>, _mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>, _handler_task: Task<()>,
} }
impl ClaudeAgentConnection { impl ClaudeAgentSession {
async fn handle_message( async fn handle_message(
delegate: AcpClientDelegate, mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage, message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>, cx: &mut AsyncApp,
) { ) {
match message { match message {
SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { SdkMessage::Assistant {
message,
session_id: _,
}
| SdkMessage::User {
message,
session_id: _,
} => {
let Some(thread) = thread_rx
.recv()
.await
.log_err()
.and_then(|entity| entity.upgrade())
else {
log::error!("Received an SDK message but thread is gone");
return;
};
for chunk in message.content.chunks() { for chunk in message.content.chunks() {
match chunk { match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
delegate thread
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams { .update(cx, |thread, cx| {
chunk: acp::AssistantMessageChunk::Text { text }, thread.push_assistant_chunk(text.into(), false, cx)
}) })
.await
.log_err(); .log_err();
} }
ContentChunk::ToolUse { id, name, input } => { ContentChunk::ToolUse { id, name, input } => {
let claude_tool = ClaudeTool::infer(&name, input); let claude_tool = ClaudeTool::infer(&name, input);
if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { thread
delegate .update(cx, |thread, cx| {
.update_plan(acp::UpdatePlanParams { if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
entries: params.todos.into_iter().map(Into::into).collect(), thread.update_plan(
}) acp::Plan {
.await entries: params
.log_err(); .todos
} else if let Some(resp) = delegate .into_iter()
.push_tool_call(claude_tool.as_acp()) .map(Into::into)
.await .collect(),
.log_err() },
{ cx,
tool_id_map.borrow_mut().insert(id, resp.id); )
} } else {
thread.upsert_tool_call(
claude_tool.as_acp(acp::ToolCallId(id.into())),
cx,
);
}
})
.log_err();
} }
ContentChunk::ToolResult { ContentChunk::ToolResult {
content, content,
tool_use_id, tool_use_id,
} => { } => {
let id = tool_id_map.borrow_mut().remove(&tool_use_id); let content = content.to_string();
if let Some(id) = id { thread
let content = content.to_string(); .update(cx, |thread, cx| {
delegate thread.update_tool_call(
.update_tool_call(UpdateToolCallParams { acp::ToolCallId(tool_use_id.into()),
tool_call_id: id, acp::ToolCallStatus::Completed,
status: acp::ToolCallStatus::Finished, (!content.is_empty()).then(|| vec![content.into()]),
// Don't unset existing content cx,
content: (!content.is_empty()).then_some( )
ToolCallContent::Markdown { })
// For now we only include text content .log_err();
markdown: content,
},
),
})
.await
.log_err();
}
} }
ContentChunk::Image ContentChunk::Image
| ContentChunk::Document | ContentChunk::Document
| ContentChunk::Thinking | ContentChunk::Thinking
| ContentChunk::RedactedThinking | ContentChunk::RedactedThinking
| ContentChunk::WebSearchToolResult => { | ContentChunk::WebSearchToolResult => {
delegate thread
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams { .update(cx, |thread, cx| {
chunk: acp::AssistantMessageChunk::Text { thread.push_assistant_chunk(
text: format!("Unsupported content: {:?}", chunk), format!("Unsupported content: {:?}", chunk).into(),
}, false,
cx,
)
}) })
.await
.log_err(); .log_err();
} }
} }
@ -592,14 +634,14 @@ enum SdkMessage {
Assistant { Assistant {
message: Message, // from Anthropic SDK message: Message, // from Anthropic SDK
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<Uuid>, session_id: Option<String>,
}, },
// A user message // A user message
User { User {
message: Message, // from Anthropic SDK message: Message, // from Anthropic SDK
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<Uuid>, session_id: Option<String>,
}, },
// Emitted as the last message in a conversation // Emitted as the last message in a conversation
@ -661,21 +703,6 @@ enum PermissionMode {
Plan, Plan,
} }
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpConfig {
mcp_servers: HashMap<String, McpServerConfig>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct McpServerConfig {
command: String,
args: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
env: Option<HashMap<String, String>>,
}
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;

View file

@ -1,29 +1,22 @@
use std::{cell::RefCell, rc::Rc}; use std::path::PathBuf;
use acp_thread::AcpClientDelegate; use acp_thread::AcpThread;
use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams}; use agent_client_protocol as acp;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use collections::HashMap; use collections::HashMap;
use context_server::{ use context_server::types::{
listener::McpServer, CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
types::{ ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, ToolResponseContent, ToolsCapabilities, requests,
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
ToolResponseContent, ToolsCapabilities, requests,
},
}; };
use gpui::{App, AsyncApp, Task}; use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::debug_panic;
use crate::claude::{ use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
McpServerConfig,
tools::{ClaudeTool, EditToolParams, ReadToolParams},
};
pub struct ClaudeMcpServer { pub struct ClaudeZedMcpServer {
server: McpServer, server: context_server::listener::McpServer,
} }
pub const SERVER_NAME: &str = "zed"; pub const SERVER_NAME: &str = "zed";
@ -52,17 +45,16 @@ enum PermissionToolBehavior {
Deny, Deny,
} }
impl ClaudeMcpServer { impl ClaudeZedMcpServer {
pub async fn new( pub async fn new(
delegate: watch::Receiver<Option<AcpClientDelegate>>, thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Result<Self> { ) -> Result<Self> {
let mut mcp_server = McpServer::new(cx).await?; let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize); mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools); mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
mcp_server.handle_request::<requests::CallTool>(move |request, cx| { mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) Self::handle_call_tool(request, thread_rx.clone(), cx)
}); });
Ok(Self { server: mcp_server }) Ok(Self { server: mcp_server })
@ -70,9 +62,7 @@ impl ClaudeMcpServer {
pub fn server_config(&self) -> Result<McpServerConfig> { pub fn server_config(&self) -> Result<McpServerConfig> {
let zed_path = std::env::current_exe() let zed_path = std::env::current_exe()
.context("finding current executable path for use in mcp_server")? .context("finding current executable path for use in mcp_server")?;
.to_string_lossy()
.to_string();
Ok(McpServerConfig { Ok(McpServerConfig {
command: zed_path, command: zed_path,
@ -152,22 +142,19 @@ impl ClaudeMcpServer {
fn handle_call_tool( fn handle_call_tool(
request: CallToolParams, request: CallToolParams,
mut delegate_watch: watch::Receiver<Option<AcpClientDelegate>>, mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
cx: &App, cx: &App,
) -> Task<Result<CallToolResponse>> { ) -> Task<Result<CallToolResponse>> {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let Some(delegate) = delegate_watch.recv().await? else { let Some(thread) = thread_rx.recv().await?.upgrade() else {
debug_panic!("Sent None delegate"); anyhow::bail!("Thread closed");
anyhow::bail!("Server not available");
}; };
if request.name.as_str() == PERMISSION_TOOL { if request.name.as_str() == PERMISSION_TOOL {
let input = let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?; serde_json::from_value(request.arguments.context("Arguments required")?)?;
let result = let result = Self::handle_permissions_tool_call(input, thread, cx).await?;
Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?;
Ok(CallToolResponse { Ok(CallToolResponse {
content: vec![ToolResponseContent::Text { content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&result)?, text: serde_json::to_string(&result)?,
@ -179,7 +166,7 @@ impl ClaudeMcpServer {
let input = let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?; serde_json::from_value(request.arguments.context("Arguments required")?)?;
let content = Self::handle_read_tool_call(input, delegate, cx).await?; let content = Self::handle_read_tool_call(input, thread, cx).await?;
Ok(CallToolResponse { Ok(CallToolResponse {
content, content,
is_error: None, is_error: None,
@ -189,7 +176,7 @@ impl ClaudeMcpServer {
let input = let input =
serde_json::from_value(request.arguments.context("Arguments required")?)?; serde_json::from_value(request.arguments.context("Arguments required")?)?;
Self::handle_edit_tool_call(input, delegate, cx).await?; Self::handle_edit_tool_call(input, thread, cx).await?;
Ok(CallToolResponse { Ok(CallToolResponse {
content: vec![], content: vec![],
is_error: None, is_error: None,
@ -202,49 +189,46 @@ impl ClaudeMcpServer {
} }
fn handle_read_tool_call( fn handle_read_tool_call(
params: ReadToolParams, ReadToolParams {
delegate: AcpClientDelegate, abs_path,
offset,
limit,
}: ReadToolParams,
thread: Entity<AcpThread>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Task<Result<Vec<ToolResponseContent>>> { ) -> Task<Result<Vec<ToolResponseContent>>> {
cx.foreground_executor().spawn(async move { cx.spawn(async move |cx| {
let response = delegate let content = thread
.read_text_file(ReadTextFileParams { .update(cx, |thread, cx| {
path: params.abs_path, thread.read_text_file(abs_path, offset, limit, false, cx)
line: params.offset, })?
limit: params.limit,
})
.await?; .await?;
Ok(vec![ToolResponseContent::Text { Ok(vec![ToolResponseContent::Text { text: content }])
text: response.content,
}])
}) })
} }
fn handle_edit_tool_call( fn handle_edit_tool_call(
params: EditToolParams, params: EditToolParams,
delegate: AcpClientDelegate, thread: Entity<AcpThread>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
cx.foreground_executor().spawn(async move { cx.spawn(async move |cx| {
let response = delegate let content = thread
.read_text_file_reusing_snapshot(ReadTextFileParams { .update(cx, |threads, cx| {
path: params.abs_path.clone(), threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
line: None, })?
limit: None,
})
.await?; .await?;
let new_content = response.content.replace(&params.old_text, &params.new_text); let new_content = content.replace(&params.old_text, &params.new_text);
if new_content == response.content { if new_content == content {
return Err(anyhow::anyhow!("The old_text was not found in the content")); return Err(anyhow::anyhow!("The old_text was not found in the content"));
} }
delegate thread
.write_text_file(WriteTextFileParams { .update(cx, |threads, cx| {
path: params.abs_path, threads.write_text_file(params.abs_path, new_content, cx)
content: new_content, })?
})
.await?; .await?;
Ok(()) Ok(())
@ -253,44 +237,65 @@ impl ClaudeMcpServer {
fn handle_permissions_tool_call( fn handle_permissions_tool_call(
params: PermissionToolParams, params: PermissionToolParams,
delegate: AcpClientDelegate, thread: Entity<AcpThread>,
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
cx: &AsyncApp, cx: &AsyncApp,
) -> Task<Result<PermissionToolResponse>> { ) -> Task<Result<PermissionToolResponse>> {
cx.foreground_executor().spawn(async move { cx.spawn(async move |cx| {
let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone()); let claude_tool = ClaudeTool::infer(&params.tool_name, params.input.clone());
let tool_call_id = match params.tool_use_id { let tool_call_id =
Some(tool_use_id) => tool_id_map acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into());
.borrow()
.get(&tool_use_id)
.cloned()
.context("Tool call ID not found")?,
None => delegate.push_tool_call(claude_tool.as_acp()).await?.id, let allow_option_id = acp::PermissionOptionId("allow".into());
}; let reject_option_id = acp::PermissionOptionId("reject".into());
let outcome = delegate let chosen_option = thread
.request_existing_tool_call_confirmation( .update(cx, |thread, cx| {
tool_call_id, thread.request_tool_call_permission(
claude_tool.confirmation(None), claude_tool.as_acp(tool_call_id),
) vec![
acp::PermissionOption {
id: allow_option_id.clone(),
label: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: reject_option_id,
label: "Reject".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
cx,
)
})?
.await?; .await?;
match outcome { if chosen_option == allow_option_id {
acp::ToolCallConfirmationOutcome::Allow Ok(PermissionToolResponse {
| acp::ToolCallConfirmationOutcome::AlwaysAllow
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
behavior: PermissionToolBehavior::Allow, behavior: PermissionToolBehavior::Allow,
updated_input: params.input, updated_input: params.input,
}), })
acp::ToolCallConfirmationOutcome::Reject } else {
| acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse { Ok(PermissionToolResponse {
behavior: PermissionToolBehavior::Deny, behavior: PermissionToolBehavior::Deny,
updated_input: params.input, updated_input: params.input,
}), })
} }
}) })
} }
} }
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct McpConfig {
pub mcp_servers: HashMap<String, McpServerConfig>,
}
#[derive(Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct McpServerConfig {
pub command: PathBuf,
pub args: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<HashMap<String, String>>,
}

View file

@ -1,6 +1,6 @@
use std::path::PathBuf; use std::path::PathBuf;
use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation}; use agent_client_protocol as acp;
use itertools::Itertools; use itertools::Itertools;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -115,51 +115,36 @@ impl ClaudeTool {
Self::Other { name, .. } => name.clone(), Self::Other { name, .. } => name.clone(),
} }
} }
pub fn content(&self) -> Vec<acp::ToolCallContent> {
pub fn content(&self) -> Option<acp::ToolCallContent> {
match &self { match &self {
Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown { Self::Other { input, .. } => vec![
markdown: format!( format!(
"```json\n{}```", "```json\n{}```",
serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
), )
}), .into(),
Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown { ],
markdown: params.prompt.clone(), Self::Task(Some(params)) => vec![params.prompt.clone().into()],
}), Self::NotebookRead(Some(params)) => {
Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown { vec![params.notebook_path.display().to_string().into()]
markdown: params.notebook_path.display().to_string(), }
}), Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()],
Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown { Self::Terminal(Some(params)) => vec![
markdown: params.new_source.clone(), format!(
}),
Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: format!(
"`{}`\n\n{}", "`{}`\n\n{}",
params.command, params.command,
params.description.as_deref().unwrap_or_default() params.description.as_deref().unwrap_or_default()
), )
}), .into(),
Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown { ],
markdown: params.abs_path.display().to_string(), Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()],
}), Self::Ls(Some(params)) => vec![params.path.display().to_string().into()],
Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown { Self::Glob(Some(params)) => vec![params.to_string().into()],
markdown: params.path.display().to_string(), Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
}), Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown { Self::WebSearch(Some(params)) => vec![params.to_string().into()],
markdown: params.to_string(), Self::TodoWrite(Some(params)) => vec![
}), params
Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: format!("`{params}`"),
}),
Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params.prompt.clone(),
}),
Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params.to_string(),
}),
Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown {
markdown: params
.todos .todos
.iter() .iter()
.map(|todo| { .map(|todo| {
@ -174,34 +159,39 @@ impl ClaudeTool {
todo.content todo.content
) )
}) })
.join("\n"), .join("\n")
}), .into(),
Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown { ],
markdown: params.plan.clone(), Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
}), Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff {
diff: acp::Diff { diff: acp::Diff {
path: params.abs_path.clone(), path: params.abs_path.clone(),
old_text: Some(params.old_text.clone()), old_text: Some(params.old_text.clone()),
new_text: params.new_text.clone(), new_text: params.new_text.clone(),
}, },
}), }],
Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff { Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff {
diff: acp::Diff { diff: acp::Diff {
path: params.file_path.clone(), path: params.file_path.clone(),
old_text: None, old_text: None,
new_text: params.content.clone(), new_text: params.content.clone(),
}, },
}), }],
Self::MultiEdit(Some(params)) => { Self::MultiEdit(Some(params)) => {
// todo: show multiple edits in a multibuffer? // todo: show multiple edits in a multibuffer?
params.edits.first().map(|edit| acp::ToolCallContent::Diff { params
diff: acp::Diff { .edits
path: params.file_path.clone(), .first()
old_text: Some(edit.old_string.clone()), .map(|edit| {
new_text: edit.new_string.clone(), vec![acp::ToolCallContent::Diff {
}, diff: acp::Diff {
}) path: params.file_path.clone(),
old_text: Some(edit.old_string.clone()),
new_text: edit.new_string.clone(),
},
}]
})
.unwrap_or_default()
} }
Self::Task(None) Self::Task(None)
| Self::NotebookRead(None) | Self::NotebookRead(None)
@ -217,181 +207,80 @@ impl ClaudeTool {
| Self::ExitPlanMode(None) | Self::ExitPlanMode(None)
| Self::Edit(None) | Self::Edit(None)
| Self::Write(None) | Self::Write(None)
| Self::MultiEdit(None) => None, | Self::MultiEdit(None) => vec![],
} }
} }
pub fn icon(&self) -> acp::Icon { pub fn kind(&self) -> acp::ToolKind {
match self { match self {
Self::Task(_) => acp::Icon::Hammer, Self::Task(_) => acp::ToolKind::Think,
Self::NotebookRead(_) => acp::Icon::FileSearch, Self::NotebookRead(_) => acp::ToolKind::Read,
Self::NotebookEdit(_) => acp::Icon::Pencil, Self::NotebookEdit(_) => acp::ToolKind::Edit,
Self::Edit(_) => acp::Icon::Pencil, Self::Edit(_) => acp::ToolKind::Edit,
Self::MultiEdit(_) => acp::Icon::Pencil, Self::MultiEdit(_) => acp::ToolKind::Edit,
Self::Write(_) => acp::Icon::Pencil, Self::Write(_) => acp::ToolKind::Edit,
Self::ReadFile(_) => acp::Icon::FileSearch, Self::ReadFile(_) => acp::ToolKind::Read,
Self::Ls(_) => acp::Icon::Folder, Self::Ls(_) => acp::ToolKind::Search,
Self::Glob(_) => acp::Icon::FileSearch, Self::Glob(_) => acp::ToolKind::Search,
Self::Grep(_) => acp::Icon::Regex, Self::Grep(_) => acp::ToolKind::Search,
Self::Terminal(_) => acp::Icon::Terminal, Self::Terminal(_) => acp::ToolKind::Execute,
Self::WebSearch(_) => acp::Icon::Globe, Self::WebSearch(_) => acp::ToolKind::Search,
Self::WebFetch(_) => acp::Icon::Globe, Self::WebFetch(_) => acp::ToolKind::Fetch,
Self::TodoWrite(_) => acp::Icon::LightBulb, Self::TodoWrite(_) => acp::ToolKind::Think,
Self::ExitPlanMode(_) => acp::Icon::Hammer, Self::ExitPlanMode(_) => acp::ToolKind::Think,
Self::Other { .. } => acp::Icon::Hammer, Self::Other { .. } => acp::ToolKind::Other,
}
}
pub fn confirmation(&self, description: Option<String>) -> acp::ToolCallConfirmation {
match &self {
Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => {
acp::ToolCallConfirmation::Edit { description }
}
Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch {
urls: params
.as_ref()
.map(|p| vec![p.url.clone()])
.unwrap_or_default(),
description,
},
Self::Terminal(Some(BashToolParams {
description,
command,
..
})) => acp::ToolCallConfirmation::Execute {
command: command.clone(),
root_command: command.clone(),
description: description.clone(),
},
Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {}", params.plan)
} else {
params.plan.clone()
},
},
Self::Task(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {}", params.description)
} else {
params.description.clone()
},
},
Self::Ls(Some(LsToolParams { path, .. }))
| Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => {
let path = path.display();
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {path}")
} else {
path.to_string()
},
}
}
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
let path = notebook_path.display();
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {path}")
} else {
path.to_string()
},
}
}
Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params.to_string()
},
},
Self::TodoWrite(Some(params)) => {
let params = params.todos.iter().map(|todo| &todo.content).join(", ");
acp::ToolCallConfirmation::Other {
description: if let Some(description) = description {
format!("{description} {params}")
} else {
params
},
}
}
Self::Terminal(None)
| Self::Task(None)
| Self::NotebookRead(None)
| Self::ExitPlanMode(None)
| Self::Ls(None)
| Self::Glob(None)
| Self::Grep(None)
| Self::ReadFile(None)
| Self::WebSearch(None)
| Self::TodoWrite(None)
| Self::Other { .. } => acp::ToolCallConfirmation::Other {
description: description.unwrap_or("".to_string()),
},
} }
} }
pub fn locations(&self) -> Vec<acp::ToolCallLocation> { pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
match &self { match &self {
Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation { Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation {
path: abs_path.clone(), path: abs_path.clone(),
line: None, line: None,
}], }],
Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
vec![ToolCallLocation { vec![acp::ToolCallLocation {
path: file_path.clone(),
line: None,
}]
}
Self::Write(Some(WriteToolParams { file_path, .. })) => {
vec![acp::ToolCallLocation {
path: file_path.clone(), path: file_path.clone(),
line: None, line: None,
}] }]
} }
Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation {
path: file_path.clone(),
line: None,
}],
Self::ReadFile(Some(ReadToolParams { Self::ReadFile(Some(ReadToolParams {
abs_path, offset, .. abs_path, offset, ..
})) => vec![ToolCallLocation { })) => vec![acp::ToolCallLocation {
path: abs_path.clone(), path: abs_path.clone(),
line: *offset, line: *offset,
}], }],
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
vec![ToolCallLocation { vec![acp::ToolCallLocation {
path: notebook_path.clone(), path: notebook_path.clone(),
line: None, line: None,
}] }]
} }
Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
vec![ToolCallLocation { vec![acp::ToolCallLocation {
path: notebook_path.clone(), path: notebook_path.clone(),
line: None, line: None,
}] }]
} }
Self::Glob(Some(GlobToolParams { Self::Glob(Some(GlobToolParams {
path: Some(path), .. path: Some(path), ..
})) => vec![ToolCallLocation { })) => vec![acp::ToolCallLocation {
path: path.clone(), path: path.clone(),
line: None, line: None,
}], }],
Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation { Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation {
path: path.clone(), path: path.clone(),
line: None, line: None,
}], }],
Self::Grep(Some(GrepToolParams { Self::Grep(Some(GrepToolParams {
path: Some(path), .. path: Some(path), ..
})) => vec![ToolCallLocation { })) => vec![acp::ToolCallLocation {
path: PathBuf::from(path), path: PathBuf::from(path),
line: None, line: None,
}], }],
@ -414,11 +303,13 @@ impl ClaudeTool {
} }
} }
pub fn as_acp(&self) -> PushToolCallParams { pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall {
PushToolCallParams { acp::ToolCall {
id,
kind: self.kind(),
status: acp::ToolCallStatus::InProgress,
label: self.label(), label: self.label(),
content: self.content(), content: self.content(),
icon: self.icon(),
locations: self.locations(), locations: self.locations(),
} }
} }

View file

@ -1,10 +1,9 @@
use std::{path::Path, sync::Arc, time::Duration}; use std::{path::Path, sync::Arc, time::Duration};
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings}; use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
use acp_thread::{ use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus, use agent_client_protocol as acp;
};
use agentic_coding_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select}; use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{Entity, TestAppContext}; use gpui::{Entity, TestAppContext};
use indoc::indoc; use indoc::indoc;
@ -54,19 +53,25 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send( thread.send(
acp::SendUserMessageParams { vec![
chunks: vec![ acp::ContentBlock::Text(acp::TextContent {
acp::UserMessageChunk::Text { text: "Read the file ".into(),
text: "Read the file ".into(), annotations: None,
}, }),
acp::UserMessageChunk::Path { acp::ContentBlock::ResourceLink(acp::ResourceLink {
path: Path::new("foo.rs").into(), uri: "foo.rs".into(),
}, name: "foo.rs".into(),
acp::UserMessageChunk::Text { annotations: None,
text: " and tell me what the content of the println! is".into(), description: None,
}, mime_type: None,
], size: None,
}, title: None,
}),
acp::ContentBlock::Text(acp::TextContent {
text: " and tell me what the content of the println! is".into(),
annotations: None,
}),
],
cx, cx,
) )
}) })
@ -161,11 +166,8 @@ pub async fn test_tool_call_with_confirmation(
let tool_call_id = thread.read_with(cx, |thread, _cx| { let tool_call_id = thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
id, id,
status: content,
ToolCallStatus::WaitingForConfirmation { status: ToolCallStatus::WaitingForConfirmation { .. },
confirmation: ToolCallConfirmation::Execute { root_command, .. },
..
},
.. ..
}) = &thread }) = &thread
.entries() .entries()
@ -176,13 +178,18 @@ pub async fn test_tool_call_with_confirmation(
panic!(); panic!();
}; };
assert!(root_command.contains("touch")); assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
*id id.clone()
}); });
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx); thread.authorize_tool_call(
tool_call_id,
acp::PermissionOptionId("0".into()),
acp::PermissionOptionKind::AllowOnce,
cx,
);
assert!(thread.entries().iter().any(|entry| matches!( assert!(thread.entries().iter().any(|entry| matches!(
entry, entry,
@ -197,7 +204,7 @@ pub async fn test_tool_call_with_confirmation(
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
content: Some(ToolCallContent::Markdown { markdown }), content,
status: ToolCallStatus::Allowed { .. }, status: ToolCallStatus::Allowed { .. },
.. ..
}) = thread }) = thread
@ -209,13 +216,10 @@ pub async fn test_tool_call_with_confirmation(
panic!(); panic!();
}; };
markdown.read_with(cx, |md, _cx| { assert!(
assert!( content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
md.source().contains("Hello"), "Expected content to contain 'Hello'"
r#"Expected '{}' to contain "Hello""#, );
md.source()
);
});
}); });
} }
@ -249,26 +253,20 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
thread.read_with(cx, |thread, _cx| { thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
id, id,
status: content,
ToolCallStatus::WaitingForConfirmation { status: ToolCallStatus::WaitingForConfirmation { .. },
confirmation: ToolCallConfirmation::Execute { root_command, .. },
..
},
.. ..
}) = &thread.entries()[first_tool_call_ix] }) = &thread.entries()[first_tool_call_ix]
else { else {
panic!("{:?}", thread.entries()[1]); panic!("{:?}", thread.entries()[1]);
}; };
assert!(root_command.contains("touch")); assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
*id id.clone()
}); });
thread let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
.update(cx, |thread, cx| thread.cancel(cx))
.await
.unwrap();
full_turn.await.unwrap(); full_turn.await.unwrap();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
@ -369,15 +367,16 @@ pub async fn new_test_thread(
current_dir: impl AsRef<Path>, current_dir: impl AsRef<Path>,
cx: &mut TestAppContext, cx: &mut TestAppContext,
) -> Entity<AcpThread> { ) -> Entity<AcpThread> {
let thread = cx let connection = cx
.update(|cx| server.new_thread(current_dir.as_ref(), &project, cx)) .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
.await .await
.unwrap(); .unwrap();
thread let thread = connection
.update(cx, |thread, _| thread.initialize()) .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
.await .await
.unwrap(); .unwrap();
thread thread
} }

View file

@ -1,9 +1,17 @@
use crate::stdio_agent_server::StdioAgentServer; use anyhow::anyhow;
use crate::{AgentServerCommand, AgentServerVersion}; use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
use util::ResultExt as _;
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
use agentic_coding_protocol as acp_old;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use gpui::{AsyncApp, Entity}; use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use ui::App;
use crate::AllAgentServersSettings; use crate::AllAgentServersSettings;
@ -12,7 +20,7 @@ pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp"; const ACP_ARG: &str = "--experimental-acp";
impl StdioAgentServer for Gemini { impl AgentServer for Gemini {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
"Gemini" "Gemini"
} }
@ -25,14 +33,88 @@ impl StdioAgentServer for Gemini {
"Ask questions, edit files, run commands.\nBe specific for the best results." "Ask questions, edit files, run commands.\nBe specific for the best results."
} }
fn supports_always_allow(&self) -> bool {
true
}
fn logo(&self) -> ui::IconName { fn logo(&self) -> ui::IconName {
ui::IconName::AiGemini ui::IconName::AiGemini
} }
fn connect(
&self,
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone();
let this = self.clone();
let name = self.name();
cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?;
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
name,
connection,
child_status,
});
Ok(connection)
})
}
}
impl Gemini {
async fn command( async fn command(
&self, &self,
project: &Entity<Project>, project: &Entity<Project>,

View file

@ -1,119 +0,0 @@
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
use agentic_coding_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
use project::Project;
use std::path::Path;
use util::ResultExt;
pub trait StdioAgentServer: Send + Clone {
fn logo(&self) -> ui::IconName;
fn name(&self) -> &'static str;
fn empty_state_headline(&self) -> &'static str;
fn empty_state_message(&self) -> &'static str;
fn supports_always_allow(&self) -> bool;
fn command(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> impl Future<Output = Result<AgentServerCommand>>;
fn version(
&self,
command: &AgentServerCommand,
) -> impl Future<Output = Result<AgentServerVersion>> + Send;
}
impl<T: StdioAgentServer + 'static> AgentServer for T {
fn name(&self) -> &'static str {
self.name()
}
fn empty_state_headline(&self) -> &'static str {
self.empty_state_headline()
}
fn empty_state_message(&self) -> &'static str {
self.empty_state_message()
}
fn logo(&self) -> ui::IconName {
self.logo()
}
fn supports_always_allow(&self) -> bool {
self.supports_always_allow()
}
fn new_thread(
&self,
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
let root_dir = root_dir.to_path_buf();
let project = project.clone();
let this = self.clone();
let title = self.name().into();
cx.spawn(async move |cx| {
let command = this.command(&project, cx).await?;
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
cx.new(|cx| {
let foreground_executor = cx.foreground_executor().clone();
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
})
})
}
}

View file

@ -17,10 +17,10 @@ test-support = ["gpui/test-support", "language/test-support"]
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true
agent.workspace = true agent.workspace = true
agentic-coding-protocol.workspace = true
agent_settings.workspace = true
agent_servers.workspace = true agent_servers.workspace = true
agent_settings.workspace = true
ai_onboarding.workspace = true ai_onboarding.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_context.workspace = true assistant_context.workspace = true

View file

@ -1,4 +1,4 @@
use acp_thread::Plan; use acp_thread::{AgentConnection, Plan};
use agent_servers::AgentServer; use agent_servers::AgentServer;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -7,7 +7,7 @@ use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use agentic_coding_protocol::{self as acp}; use agent_client_protocol as acp;
use assistant_tool::ActionLog; use assistant_tool::ActionLog;
use buffer_diff::BufferDiff; use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
@ -16,7 +16,6 @@ use editor::{
EditorStyle, MinimapVisibility, MultiBuffer, PathKey, EditorStyle, MinimapVisibility, MultiBuffer, PathKey,
}; };
use file_icons::FileIcons; use file_icons::FileIcons;
use futures::channel::oneshot;
use gpui::{ use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId, Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
@ -39,8 +38,7 @@ use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{ use ::acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
ToolCallId, ToolCallStatus,
}; };
use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet}; use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSet};
@ -64,12 +62,13 @@ pub struct AcpThreadView {
last_error: Option<Entity<Markdown>>, last_error: Option<Entity<Markdown>>,
list_state: ListState, list_state: ListState,
auth_task: Option<Task<()>>, auth_task: Option<Task<()>>,
expanded_tool_calls: HashSet<ToolCallId>, expanded_tool_calls: HashSet<acp::ToolCallId>,
expanded_thinking_blocks: HashSet<(usize, usize)>, expanded_thinking_blocks: HashSet<(usize, usize)>,
edits_expanded: bool, edits_expanded: bool,
plan_expanded: bool, plan_expanded: bool,
editor_expanded: bool, editor_expanded: bool,
message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>, message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
_cancel_task: Option<Task<()>>,
} }
enum ThreadState { enum ThreadState {
@ -82,22 +81,16 @@ enum ThreadState {
}, },
LoadError(LoadError), LoadError(LoadError),
Unauthenticated { Unauthenticated {
thread: Entity<AcpThread>, connection: Rc<dyn AgentConnection>,
}, },
} }
struct AlwaysAllowOption {
id: &'static str,
label: SharedString,
outcome: acp::ToolCallConfirmationOutcome,
}
impl AcpThreadView { impl AcpThreadView {
pub fn new( pub fn new(
agent: Rc<dyn AgentServer>, agent: Rc<dyn AgentServer>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>, project: Entity<Project>,
message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>, message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
min_lines: usize, min_lines: usize,
max_lines: Option<usize>, max_lines: Option<usize>,
window: &mut Window, window: &mut Window,
@ -191,6 +184,7 @@ impl AcpThreadView {
plan_expanded: false, plan_expanded: false,
editor_expanded: false, editor_expanded: false,
message_history, message_history,
_cancel_task: None,
} }
} }
@ -208,9 +202,9 @@ impl AcpThreadView {
.map(|worktree| worktree.read(cx).abs_path()) .map(|worktree| worktree.read(cx).abs_path())
.unwrap_or_else(|| paths::home_dir().as_path().into()); .unwrap_or_else(|| paths::home_dir().as_path().into());
let task = agent.new_thread(&root_dir, &project, cx); let connect_task = agent.connect(&root_dir, &project, cx);
let load_task = cx.spawn_in(window, async move |this, cx| { let load_task = cx.spawn_in(window, async move |this, cx| {
let thread = match task.await { let connection = match connect_task.await {
Ok(thread) => thread, Ok(thread) => thread,
Err(err) => { Err(err) => {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
@ -222,48 +216,30 @@ impl AcpThreadView {
} }
}; };
let init_response = async { let result = match connection
let resp = thread .clone()
.read_with(cx, |thread, _cx| thread.initialize())? .new_thread(project.clone(), &root_dir, cx)
.await?; .await
anyhow::Ok(resp) {
};
let result = match init_response.await {
Err(e) => { Err(e) => {
let mut cx = cx.clone(); let mut cx = cx.clone();
if e.downcast_ref::<oneshot::Canceled>().is_some() { if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
let child_status = thread this.update(&mut cx, |this, cx| {
.update(&mut cx, |thread, _| thread.child_status()) this.thread_state = ThreadState::Unauthenticated { connection };
.ok() cx.notify();
.flatten(); })
if let Some(child_status) = child_status { .ok();
match child_status.await { return;
Ok(_) => Err(e),
Err(e) => Err(e),
}
} else {
Err(e)
}
} else { } else {
Err(e) Err(e)
} }
} }
Ok(response) => { Ok(session_id) => Ok(session_id),
if !response.is_authenticated {
this.update(cx, |this, _| {
this.thread_state = ThreadState::Unauthenticated { thread };
})
.ok();
return;
};
Ok(())
}
}; };
this.update_in(cx, |this, window, cx| { this.update_in(cx, |this, window, cx| {
match result { match result {
Ok(()) => { Ok(thread) => {
let thread_subscription = let thread_subscription =
cx.subscribe_in(&thread, window, Self::handle_thread_event); cx.subscribe_in(&thread, window, Self::handle_thread_event);
@ -305,10 +281,10 @@ impl AcpThreadView {
pub fn thread(&self) -> Option<&Entity<AcpThread>> { pub fn thread(&self) -> Option<&Entity<AcpThread>> {
match &self.thread_state { match &self.thread_state {
ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { ThreadState::Ready { thread, .. } => Some(thread),
Some(thread) ThreadState::Unauthenticated { .. }
} | ThreadState::Loading { .. }
ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, | ThreadState::LoadError(..) => None,
} }
} }
@ -325,7 +301,7 @@ impl AcpThreadView {
self.last_error.take(); self.last_error.take();
if let Some(thread) = self.thread() { if let Some(thread) = self.thread() {
thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
} }
} }
@ -362,7 +338,7 @@ impl AcpThreadView {
self.last_error.take(); self.last_error.take();
let mut ix = 0; let mut ix = 0;
let mut chunks: Vec<acp::UserMessageChunk> = Vec::new(); let mut chunks: Vec<acp::ContentBlock> = Vec::new();
let project = self.project.clone(); let project = self.project.clone();
self.message_editor.update(cx, |editor, cx| { self.message_editor.update(cx, |editor, cx| {
let text = editor.text(cx); let text = editor.text(cx);
@ -374,12 +350,19 @@ impl AcpThreadView {
{ {
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot); let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
if crease_range.start > ix { if crease_range.start > ix {
chunks.push(acp::UserMessageChunk::Text { chunks.push(text[ix..crease_range.start].into());
text: text[ix..crease_range.start].to_string(),
});
} }
if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) { if let Some(abs_path) = project.read(cx).absolute_path(&project_path, cx) {
chunks.push(acp::UserMessageChunk::Path { path: abs_path }); let path_str = abs_path.display().to_string();
chunks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink {
uri: path_str.clone(),
name: path_str,
annotations: None,
description: None,
mime_type: None,
size: None,
title: None,
}));
} }
ix = crease_range.end; ix = crease_range.end;
} }
@ -388,9 +371,7 @@ impl AcpThreadView {
if ix < text.len() { if ix < text.len() {
let last_chunk = text[ix..].trim(); let last_chunk = text[ix..].trim();
if !last_chunk.is_empty() { if !last_chunk.is_empty() {
chunks.push(acp::UserMessageChunk::Text { chunks.push(last_chunk.into());
text: last_chunk.into(),
});
} }
} }
}) })
@ -401,8 +382,7 @@ impl AcpThreadView {
} }
let Some(thread) = self.thread() else { return }; let Some(thread) = self.thread() else { return };
let message = acp::SendUserMessageParams { chunks }; let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
let task = thread.update(cx, |thread, cx| thread.send(message.clone(), cx));
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let result = task.await; let result = task.await;
@ -424,7 +404,7 @@ impl AcpThreadView {
editor.remove_creases(mention_set.lock().drain(), cx) editor.remove_creases(mention_set.lock().drain(), cx)
}); });
self.message_history.borrow_mut().push(message); self.message_history.borrow_mut().push(chunks);
} }
fn previous_history_message( fn previous_history_message(
@ -490,7 +470,7 @@ impl AcpThreadView {
message_editor: Entity<Editor>, message_editor: Entity<Editor>,
mention_set: Arc<Mutex<MentionSet>>, mention_set: Arc<Mutex<MentionSet>>,
project: Entity<Project>, project: Entity<Project>,
message: Option<&acp::SendUserMessageParams>, message: Option<&Vec<acp::ContentBlock>>,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> bool { ) -> bool {
@ -503,18 +483,19 @@ impl AcpThreadView {
let mut text = String::new(); let mut text = String::new();
let mut mentions = Vec::new(); let mut mentions = Vec::new();
for chunk in &message.chunks { for chunk in message {
match chunk { match chunk {
acp::UserMessageChunk::Text { text: chunk } => { acp::ContentBlock::Text(text_content) => {
text.push_str(&chunk); text.push_str(&text_content.text);
} }
acp::UserMessageChunk::Path { path } => { acp::ContentBlock::ResourceLink(resource_link) => {
let path = Path::new(&resource_link.uri);
let start = text.len(); let start = text.len();
let content = MentionPath::new(path).to_string(); let content = MentionPath::new(&path).to_string();
text.push_str(&content); text.push_str(&content);
let end = text.len(); let end = text.len();
if let Some(project_path) = if let Some(project_path) =
project.read(cx).project_path_for_absolute_path(path, cx) project.read(cx).project_path_for_absolute_path(&path, cx)
{ {
let filename: SharedString = path let filename: SharedString = path
.file_name() .file_name()
@ -525,6 +506,9 @@ impl AcpThreadView {
mentions.push((start..end, project_path, filename)); mentions.push((start..end, project_path, filename));
} }
} }
acp::ContentBlock::Image(_)
| acp::ContentBlock::Audio(_)
| acp::ContentBlock::Resource(_) => {}
} }
} }
@ -590,71 +574,79 @@ impl AcpThreadView {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(multibuffer) = self.entry_diff_multibuffer(entry_ix, cx) else { let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else {
return; return;
}; };
if self.diff_editors.contains_key(&multibuffer.entity_id()) { let multibuffers = multibuffers.collect::<Vec<_>>();
return;
}
let editor = cx.new(|cx| { for multibuffer in multibuffers {
let mut editor = Editor::new( if self.diff_editors.contains_key(&multibuffer.entity_id()) {
EditorMode::Full { return;
scale_ui_elements_with_buffer_font_size: false, }
show_active_line_background: false,
sized_by_content: true, let editor = cx.new(|cx| {
}, let mut editor = Editor::new(
multibuffer.clone(), EditorMode::Full {
None, scale_ui_elements_with_buffer_font_size: false,
window, show_active_line_background: false,
cx, sized_by_content: true,
); },
editor.set_show_gutter(false, cx); multibuffer.clone(),
editor.disable_inline_diagnostics(); None,
editor.disable_expand_excerpt_buttons(cx); window,
editor.set_show_vertical_scrollbar(false, cx); cx,
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); );
editor.set_soft_wrap_mode(SoftWrap::None, cx); editor.set_show_gutter(false, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true); editor.disable_inline_diagnostics();
editor.set_show_indent_guides(false, cx); editor.disable_expand_excerpt_buttons(cx);
editor.set_read_only(true); editor.set_show_vertical_scrollbar(false, cx);
editor.set_show_breakpoints(false, cx); editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_show_code_actions(false, cx); editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.set_show_git_diff_gutter(false, cx); editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_expand_all_diff_hunks(cx); editor.set_show_indent_guides(false, cx);
editor.set_text_style_refinement(TextStyleRefinement { editor.set_read_only(true);
font_size: Some( editor.set_show_breakpoints(false, cx);
TextSize::Small editor.set_show_code_actions(false, cx);
.rems(cx) editor.set_show_git_diff_gutter(false, cx);
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx)) editor.set_expand_all_diff_hunks(cx);
.into(), editor.set_text_style_refinement(TextStyleRefinement {
), font_size: Some(
..Default::default() TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
});
editor
}); });
editor let entity_id = multibuffer.entity_id();
}); cx.observe_release(&multibuffer, move |this, _, _| {
let entity_id = multibuffer.entity_id(); this.diff_editors.remove(&entity_id);
cx.observe_release(&multibuffer, move |this, _, _| { })
this.diff_editors.remove(&entity_id); .detach();
})
.detach();
self.diff_editors.insert(entity_id, editor); self.diff_editors.insert(entity_id, editor);
}
} }
fn entry_diff_multibuffer(&self, entry_ix: usize, cx: &App) -> Option<Entity<MultiBuffer>> { fn entry_diff_multibuffers(
&self,
entry_ix: usize,
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?; let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
entry.diff().map(|diff| diff.multibuffer.clone()) Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
} }
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let Some(thread) = self.thread().cloned() else { let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return; return;
}; };
self.last_error.take(); self.last_error.take();
let authenticate = thread.read(cx).authenticate(); let authenticate = connection.authenticate(cx);
self.auth_task = Some(cx.spawn_in(window, { self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone(); let project = self.project.clone();
let agent = self.agent.clone(); let agent = self.agent.clone();
@ -684,15 +676,16 @@ impl AcpThreadView {
fn authorize_tool_call( fn authorize_tool_call(
&mut self, &mut self,
id: ToolCallId, tool_call_id: acp::ToolCallId,
outcome: acp::ToolCallConfirmationOutcome, option_id: acp::PermissionOptionId,
option_kind: acp::PermissionOptionKind,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(thread) = self.thread() else { let Some(thread) = self.thread() else {
return; return;
}; };
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.authorize_tool_call(id, outcome, cx); thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx);
}); });
cx.notify(); cx.notify();
} }
@ -719,10 +712,12 @@ impl AcpThreadView {
.border_1() .border_1()
.border_color(cx.theme().colors().border) .border_color(cx.theme().colors().border)
.text_xs() .text_xs()
.child(self.render_markdown( .children(message.content.markdown().map(|md| {
message.content.clone(), self.render_markdown(
user_message_markdown_style(window, cx), md.clone(),
)), user_message_markdown_style(window, cx),
)
})),
) )
.into_any(), .into_any(),
AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => { AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) => {
@ -730,20 +725,28 @@ impl AcpThreadView {
let message_body = v_flex() let message_body = v_flex()
.w_full() .w_full()
.gap_2p5() .gap_2p5()
.children(chunks.iter().enumerate().map(|(chunk_ix, chunk)| { .children(chunks.iter().enumerate().filter_map(
match chunk { |(chunk_ix, chunk)| match chunk {
AssistantMessageChunk::Text { chunk } => self AssistantMessageChunk::Message { block } => {
.render_markdown(chunk.clone(), style.clone()) block.markdown().map(|md| {
.into_any_element(), self.render_markdown(md.clone(), style.clone())
AssistantMessageChunk::Thought { chunk } => self.render_thinking_block( .into_any_element()
index, })
chunk_ix, }
chunk.clone(), AssistantMessageChunk::Thought { block } => {
window, block.markdown().map(|md| {
cx, self.render_thinking_block(
), index,
} chunk_ix,
})) md.clone(),
window,
cx,
)
.into_any_element()
})
}
},
))
.into_any(); .into_any();
v_flex() v_flex()
@ -871,7 +874,7 @@ impl AcpThreadView {
let status_icon = match &tool_call.status { let status_icon = match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { .. } => None, ToolCallStatus::WaitingForConfirmation { .. } => None,
ToolCallStatus::Allowed { ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Running, status: acp::ToolCallStatus::InProgress,
.. ..
} => Some( } => Some(
Icon::new(IconName::ArrowCircle) Icon::new(IconName::ArrowCircle)
@ -885,13 +888,13 @@ impl AcpThreadView {
.into_any(), .into_any(),
), ),
ToolCallStatus::Allowed { ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Finished, status: acp::ToolCallStatus::Completed,
.. ..
} => None, } => None,
ToolCallStatus::Rejected ToolCallStatus::Rejected
| ToolCallStatus::Canceled | ToolCallStatus::Canceled
| ToolCallStatus::Allowed { | ToolCallStatus::Allowed {
status: acp::ToolCallStatus::Error, status: acp::ToolCallStatus::Failed,
.. ..
} => Some( } => Some(
Icon::new(IconName::X) Icon::new(IconName::X)
@ -909,34 +912,9 @@ impl AcpThreadView {
.any(|content| matches!(content, ToolCallContent::Diff { .. })), .any(|content| matches!(content, ToolCallContent::Diff { .. })),
}; };
let is_collapsible = tool_call.content.is_some() && !needs_confirmation; let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id); let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
let content = if is_open {
match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { confirmation, .. } => {
Some(self.render_tool_call_confirmation(
tool_call.id,
confirmation,
tool_call.content.as_ref(),
window,
cx,
))
}
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
tool_call.content.as_ref().map(|content| {
div()
.py_1p5()
.child(self.render_tool_call_content(content, window, cx))
.into_any_element()
})
}
ToolCallStatus::Rejected => None,
}
} else {
None
};
v_flex() v_flex()
.when(needs_confirmation, |this| { .when(needs_confirmation, |this| {
this.rounded_lg() this.rounded_lg()
@ -976,9 +954,17 @@ impl AcpThreadView {
}) })
.gap_1p5() .gap_1p5()
.child( .child(
Icon::new(tool_call.icon) Icon::new(match tool_call.kind {
.size(IconSize::Small) acp::ToolKind::Read => IconName::ToolRead,
.color(Color::Muted), acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolBulb,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
.size(IconSize::Small)
.color(Color::Muted),
) )
.child(if tool_call.locations.len() == 1 { .child(if tool_call.locations.len() == 1 {
let name = tool_call.locations[0] let name = tool_call.locations[0]
@ -1023,16 +1009,16 @@ impl AcpThreadView {
.gap_0p5() .gap_0p5()
.when(is_collapsible, |this| { .when(is_collapsible, |this| {
this.child( this.child(
Disclosure::new(("expand", tool_call.id.0), is_open) Disclosure::new(("expand", entry_ix), is_open)
.opened_icon(IconName::ChevronUp) .opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown) .closed_icon(IconName::ChevronDown)
.on_click(cx.listener({ .on_click(cx.listener({
let id = tool_call.id; let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| { move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open { if is_open {
this.expanded_tool_calls.remove(&id); this.expanded_tool_calls.remove(&id);
} else { } else {
this.expanded_tool_calls.insert(id); this.expanded_tool_calls.insert(id.clone());
} }
cx.notify(); cx.notify();
} }
@ -1042,12 +1028,12 @@ impl AcpThreadView {
.children(status_icon), .children(status_icon),
) )
.on_click(cx.listener({ .on_click(cx.listener({
let id = tool_call.id; let id = tool_call.id.clone();
move |this: &mut Self, _, _, cx: &mut Context<Self>| { move |this: &mut Self, _, _, cx: &mut Context<Self>| {
if is_open { if is_open {
this.expanded_tool_calls.remove(&id); this.expanded_tool_calls.remove(&id);
} else { } else {
this.expanded_tool_calls.insert(id); this.expanded_tool_calls.insert(id.clone());
} }
cx.notify(); cx.notify();
} }
@ -1055,7 +1041,7 @@ impl AcpThreadView {
) )
.when(is_open, |this| { .when(is_open, |this| {
this.child( this.child(
div() v_flex()
.text_xs() .text_xs()
.when(is_collapsible, |this| { .when(is_collapsible, |this| {
this.mt_1() this.mt_1()
@ -1064,7 +1050,44 @@ impl AcpThreadView {
.bg(cx.theme().colors().editor_background) .bg(cx.theme().colors().editor_background)
.rounded_lg() .rounded_lg()
}) })
.children(content), .map(|this| {
if is_open {
match &tool_call.status {
ToolCallStatus::WaitingForConfirmation { options, .. } => this
.children(tool_call.content.iter().map(|content| {
div()
.py_1p5()
.child(
self.render_tool_call_content(
content, window, cx,
),
)
.into_any_element()
}))
.child(self.render_permission_buttons(
options,
entry_ix,
tool_call.id.clone(),
cx,
)),
ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {
this.children(tool_call.content.iter().map(|content| {
div()
.py_1p5()
.child(
self.render_tool_call_content(
content, window, cx,
),
)
.into_any_element()
}))
}
ToolCallStatus::Rejected => this,
}
} else {
this
}
}),
) )
}) })
} }
@ -1076,14 +1099,20 @@ impl AcpThreadView {
cx: &Context<Self>, cx: &Context<Self>,
) -> AnyElement { ) -> AnyElement {
match content { match content {
ToolCallContent::Markdown { markdown } => { ToolCallContent::ContentBlock { content } => {
div() if let Some(md) = content.markdown() {
.p_2() div()
.child(self.render_markdown( .p_2()
markdown.clone(), .child(
default_markdown_style(false, window, cx), self.render_markdown(
)) md.clone(),
.into_any_element() default_markdown_style(false, window, cx),
),
)
.into_any_element()
} else {
Empty.into_any_element()
}
} }
ToolCallContent::Diff { ToolCallContent::Diff {
diff: Diff { multibuffer, .. }, diff: Diff { multibuffer, .. },
@ -1092,223 +1121,53 @@ impl AcpThreadView {
} }
} }
fn render_tool_call_confirmation( fn render_permission_buttons(
&self, &self,
tool_call_id: ToolCallId, options: &[acp::PermissionOption],
confirmation: &ToolCallConfirmation, entry_ix: usize,
content: Option<&ToolCallContent>, tool_call_id: acp::ToolCallId,
window: &Window,
cx: &Context<Self>,
) -> AnyElement {
let confirmation_container = v_flex().mt_1().py_1p5();
match confirmation {
ToolCallConfirmation::Edit { description } => confirmation_container
.child(
div()
.px_2()
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow Edits".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Execute {
command,
root_command,
description,
} => confirmation_container
.child(v_flex().px_2().pb_1p5().child(command.clone()).children(
description.clone().map(|description| {
self.render_markdown(description, default_markdown_style(false, window, cx))
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
Self::open_link(text, &workspace, window, cx);
}
})
}),
))
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: format!("Always Allow {root_command}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Mcp {
server_name,
tool_name: _,
tool_display_name,
description,
} => confirmation_container
.child(
v_flex()
.px_2()
.pb_1p5()
.child(format!("{server_name} - {tool_display_name}"))
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[
AlwaysAllowOption {
id: "always_allow_server",
label: format!("Always Allow {server_name}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
},
AlwaysAllowOption {
id: "always_allow_tool",
label: format!("Always Allow {tool_display_name}").into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
},
],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Fetch { description, urls } => confirmation_container
.child(
v_flex()
.px_2()
.pb_1p5()
.gap_1()
.children(urls.iter().map(|url| {
h_flex().child(
Button::new(url.clone(), url)
.icon(IconName::ArrowUpRight)
.icon_color(Color::Muted)
.icon_size(IconSize::XSmall)
.on_click({
let url = url.clone();
move |_, _, cx| cx.open_url(&url)
}),
)
}))
.children(description.clone().map(|description| {
self.render_markdown(
description,
default_markdown_style(false, window, cx),
)
})),
)
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
ToolCallConfirmation::Other { description } => confirmation_container
.child(v_flex().px_2().pb_1p5().child(self.render_markdown(
description.clone(),
default_markdown_style(false, window, cx),
)))
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
.child(self.render_confirmation_buttons(
&[AlwaysAllowOption {
id: "always_allow",
label: "Always Allow".into(),
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
}],
tool_call_id,
cx,
))
.into_any(),
}
}
fn render_confirmation_buttons(
&self,
always_allow_options: &[AlwaysAllowOption],
tool_call_id: ToolCallId,
cx: &Context<Self>, cx: &Context<Self>,
) -> Div { ) -> Div {
h_flex() h_flex()
.pt_1p5() .py_1p5()
.px_1p5() .px_1p5()
.gap_1() .gap_1()
.justify_end() .justify_end()
.border_t_1() .border_t_1()
.border_color(self.tool_card_border_color(cx)) .border_color(self.tool_card_border_color(cx))
.when(self.agent.supports_always_allow(), |this| { .children(options.iter().map(|option| {
this.children(always_allow_options.into_iter().map(|always_allow_option| { let option_id = SharedString::from(option.id.0.clone());
let outcome = always_allow_option.outcome; Button::new((option_id, entry_ix), option.label.clone())
Button::new( .map(|this| match option.kind {
(always_allow_option.id, tool_call_id.0), acp::PermissionOptionKind::AllowOnce => {
always_allow_option.label.clone(), this.icon(IconName::Check).icon_color(Color::Success)
) }
.icon(IconName::CheckDouble) acp::PermissionOptionKind::AllowAlways => {
this.icon(IconName::CheckDouble).icon_color(Color::Success)
}
acp::PermissionOptionKind::RejectOnce => {
this.icon(IconName::X).icon_color(Color::Error)
}
acp::PermissionOptionKind::RejectAlways => {
this.icon(IconName::X).icon_color(Color::Error)
}
})
.icon_position(IconPosition::Start) .icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall) .icon_size(IconSize::XSmall)
.icon_color(Color::Success)
.on_click(cx.listener({ .on_click(cx.listener({
let id = tool_call_id; let tool_call_id = tool_call_id.clone();
let option_id = option.id.clone();
let option_kind = option.kind;
move |this, _, _, cx| { move |this, _, _, cx| {
this.authorize_tool_call(id, outcome, cx); this.authorize_tool_call(
tool_call_id.clone(),
option_id.clone(),
option_kind,
cx,
);
} }
})) }))
})) }))
})
.child(
Button::new(("allow", tool_call_id.0), "Allow")
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Success)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Allow,
cx,
);
}
})),
)
.child(
Button::new(("reject", tool_call_id.0), "Reject")
.icon(IconName::X)
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon_color(Color::Error)
.on_click(cx.listener({
let id = tool_call_id;
move |this, _, _, cx| {
this.authorize_tool_call(
id,
acp::ToolCallConfirmationOutcome::Reject,
cx,
);
}
})),
)
} }
fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement { fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement {
@ -2245,12 +2104,11 @@ impl AcpThreadView {
.languages .languages
.language_for_name("Markdown"); .language_for_name("Markdown");
let (thread_summary, markdown) = match &self.thread_state { let (thread_summary, markdown) = if let Some(thread) = self.thread() {
ThreadState::Ready { thread, .. } | ThreadState::Unauthenticated { thread } => { let thread = thread.read(cx);
let thread = thread.read(cx); (thread.title().to_string(), thread.to_markdown(cx))
(thread.title().to_string(), thread.to_markdown(cx)) } else {
} return Task::ready(Ok(()));
ThreadState::Loading { .. } | ThreadState::LoadError(..) => return Task::ready(Ok(())),
}; };
window.spawn(cx, async move |cx| { window.spawn(cx, async move |cx| {

View file

@ -1506,8 +1506,7 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.last() .last()
.and_then(|entry| entry.diff()) .map_or(false, |entry| entry.diffs().next().is_some())
.is_some()
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
@ -1517,8 +1516,7 @@ impl AgentDiff {
.read(cx) .read(cx)
.entries() .entries()
.get(*ix) .get(*ix)
.and_then(|entry| entry.diff()) .map_or(false, |entry| entry.diffs().next().is_some())
.is_some()
{ {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }

View file

@ -440,7 +440,7 @@ pub struct AgentPanel {
local_timezone: UtcOffset, local_timezone: UtcOffset,
active_view: ActiveView, active_view: ActiveView,
acp_message_history: acp_message_history:
Rc<RefCell<crate::acp::MessageHistory<agentic_coding_protocol::SendUserMessageParams>>>, Rc<RefCell<crate::acp::MessageHistory<Vec<agent_client_protocol::ContentBlock>>>>,
previous_view: Option<ActiveView>, previous_view: Option<ActiveView>,
history_store: Entity<HistoryStore>, history_store: Entity<HistoryStore>,
history: Entity<ThreadHistory>, history: Entity<ThreadHistory>,

View file

@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use collections::HashMap; use collections::HashMap;
use futures::{FutureExt, StreamExt, channel::oneshot, select}; use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task}; use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex; use parking_lot::Mutex;
use postage::barrier; use postage::barrier;
@ -10,15 +10,19 @@ use smol::channel;
use std::{ use std::{
fmt, fmt,
path::PathBuf, path::PathBuf,
pin::pin,
sync::{ sync::{
Arc, Arc,
atomic::{AtomicI32, Ordering::SeqCst}, atomic::{AtomicI32, Ordering::SeqCst},
}, },
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use util::TryFutureExt; use util::{ResultExt, TryFutureExt};
use crate::transport::{StdioTransport, Transport}; use crate::{
transport::{StdioTransport, Transport},
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
};
const JSON_RPC_VERSION: &str = "2.0"; const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60); const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@ -32,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>; type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>; type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
@ -78,6 +83,15 @@ pub struct Request<'a, T> {
pub params: T, pub params: T,
} }
#[derive(Serialize, Deserialize)]
pub struct AnyRequest<'a> {
pub jsonrpc: &'a str,
pub id: RequestId,
pub method: &'a str,
#[serde(skip_serializing_if = "is_null_value")]
pub params: Option<&'a RawValue>,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct AnyResponse<'a> { struct AnyResponse<'a> {
jsonrpc: &'a str, jsonrpc: &'a str,
@ -176,15 +190,23 @@ impl Client {
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers = let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
let receive_input_task = cx.spawn({ let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone(); let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone(); let response_handlers = response_handlers.clone();
let request_handlers = request_handlers.clone();
let transport = transport.clone(); let transport = transport.clone();
async move |cx| { async move |cx| {
Self::handle_input(transport, notification_handlers, response_handlers, cx) Self::handle_input(
.log_err() transport,
.await notification_handlers,
request_handlers,
response_handlers,
cx,
)
.log_err()
.await
} }
}); });
let receive_err_task = cx.spawn({ let receive_err_task = cx.spawn({
@ -230,13 +252,24 @@ impl Client {
async fn handle_input( async fn handle_input(
transport: Arc<dyn Transport>, transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>, notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>, response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut receiver = transport.receive(); let mut receiver = transport.receive();
while let Some(message) = receiver.next().await { while let Some(message) = receiver.next().await {
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) { log::trace!("recv: {}", &message);
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
let mut request_handlers = request_handlers.lock();
if let Some(handler) = request_handlers.get_mut(request.method) {
handler(
request.id,
request.params.unwrap_or(RawValue::NULL),
cx.clone(),
);
}
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) { if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string())); handler(Ok(message.to_string()));
@ -247,6 +280,8 @@ impl Client {
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone()); handler(notification.params.unwrap_or(Value::Null), cx.clone());
} }
} else {
log::error!("Unhandled JSON from context_server: {}", message);
} }
} }
@ -294,6 +329,24 @@ impl Client {
&self, &self,
method: &str, method: &str,
params: impl Serialize, params: impl Serialize,
) -> Result<T> {
self.request_impl(method, params, None).await
}
pub async fn cancellable_request<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
cancel_rx: oneshot::Receiver<()>,
) -> Result<T> {
self.request_impl(method, params, Some(cancel_rx)).await
}
pub async fn request_impl<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
cancel_rx: Option<oneshot::Receiver<()>>,
) -> Result<T> { ) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst); let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request { let request = serde_json::to_string(&Request {
@ -330,6 +383,16 @@ impl Client {
send?; send?;
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse(); let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
let mut cancel_fut = pin!(
match cancel_rx {
Some(rx) => future::Either::Left(async {
rx.await.log_err();
}),
None => future::Either::Right(future::pending()),
}
.fuse()
);
select! { select! {
response = rx.fuse() => { response = rx.fuse() => {
let elapsed = started.elapsed(); let elapsed = started.elapsed();
@ -348,6 +411,16 @@ impl Client {
Err(_) => anyhow::bail!("cancelled") Err(_) => anyhow::bail!("cancelled")
} }
} }
_ = cancel_fut => {
self.notify(
Cancelled::METHOD,
ClientNotification::Cancelled(CancelledParams {
request_id: RequestId::Int(id),
reason: None
})
).log_err();
anyhow::bail!("Request cancelled")
}
_ = timeout => { _ = timeout => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT); log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
anyhow::bail!("Context server request timeout"); anyhow::bail!("Context server request timeout");

View file

@ -6,6 +6,9 @@
//! of messages. //! of messages.
use anyhow::Result; use anyhow::Result;
use futures::channel::oneshot;
use gpui::AsyncApp;
use serde_json::Value;
use crate::client::Client; use crate::client::Client;
use crate::types::{self, Notification, Request}; use crate::types::{self, Notification, Request};
@ -95,7 +98,24 @@ impl InitializedContextServerProtocol {
self.inner.request(T::METHOD, params).await self.inner.request(T::METHOD, params).await
} }
pub async fn cancellable_request<T: Request>(
&self,
params: T::Params,
cancel_rx: oneshot::Receiver<()>,
) -> Result<T::Response> {
self.inner
.cancellable_request(T::METHOD, params, cancel_rx)
.await
}
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> { pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params) self.inner.notify(T::METHOD, params)
} }
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.inner.on_notification(method, f);
}
} }

View file

@ -3,6 +3,8 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
use crate::client::RequestId;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26"; pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05"; pub const VERSION_2024_11_05: &str = "2024-11-05";
@ -100,6 +102,7 @@ pub mod notifications {
notification!("notifications/initialized", Initialized, ()); notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams); notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams); notification!("notifications/message", Message, MessageParams);
notification!("notifications/cancelled", Cancelled, CancelledParams);
notification!( notification!(
"notifications/resources/updated", "notifications/resources/updated",
ResourcesUpdated, ResourcesUpdated,
@ -617,11 +620,14 @@ pub enum ClientNotification {
Initialized, Initialized,
Progress(ProgressParams), Progress(ProgressParams),
RootsListChanged, RootsListChanged,
Cancelled { Cancelled(CancelledParams),
request_id: String, }
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>, #[derive(Debug, Serialize, Deserialize)]
}, pub struct CancelledParams {
pub request_id: RequestId,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]