Claude experiment (#34577)
Release Notes: - N/A --------- Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com> Co-authored-by: Anthony Eid <hello@anthonyeid.me> Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com> Co-authored-by: Nathan Sobo <nathan@zed.dev> Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
parent
5b97cd1900
commit
8e4555455c
33 changed files with 3437 additions and 1170 deletions
|
@ -1,5 +1,5 @@
|
|||
[package]
|
||||
name = "acp"
|
||||
name = "acp_thread"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
@ -9,15 +9,13 @@ license = "GPL-3.0-or-later"
|
|||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/acp.rs"
|
||||
path = "src/acp_thread.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "project/test-support"]
|
||||
gemini = []
|
||||
|
||||
[dependencies]
|
||||
agent_servers.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
|
@ -29,6 +27,8 @@ itertools.workspace = true
|
|||
language.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
ui.workspace = true
|
||||
|
@ -41,7 +41,6 @@ env_logger.workspace = true
|
|||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
serde_json.workspace = true
|
||||
tempfile.workspace = true
|
||||
util.workspace = true
|
||||
settings.workspace = true
|
|
@ -1,7 +1,12 @@
|
|||
mod connection;
|
||||
pub use connection::*;
|
||||
|
||||
pub use acp::ToolCallId;
|
||||
use agent_servers::AgentServer;
|
||||
use agentic_coding_protocol::{self as acp, ToolCallLocation, UserMessageChunk};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use agentic_coding_protocol::{
|
||||
self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation,
|
||||
UserMessageChunk,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use assistant_tool::ActionLog;
|
||||
use buffer_diff::BufferDiff;
|
||||
use editor::{Bias, MultiBuffer, PathKey};
|
||||
|
@ -97,7 +102,7 @@ pub struct AssistantMessage {
|
|||
}
|
||||
|
||||
impl AssistantMessage {
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
format!(
|
||||
"## Assistant\n\n{}\n\n",
|
||||
self.chunks
|
||||
|
@ -455,9 +460,8 @@ pub struct AcpThread {
|
|||
action_log: Entity<ActionLog>,
|
||||
shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
|
||||
send_task: Option<Task<()>>,
|
||||
connection: Arc<acp::AgentConnection>,
|
||||
connection: Arc<dyn AgentConnection>,
|
||||
child_status: Option<Task<Result<()>>>,
|
||||
_io_task: Task<()>,
|
||||
}
|
||||
|
||||
pub enum AcpThreadEvent {
|
||||
|
@ -476,7 +480,11 @@ pub enum ThreadStatus {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LoadError {
|
||||
Unsupported { current_version: SharedString },
|
||||
Unsupported {
|
||||
error_message: SharedString,
|
||||
upgrade_message: SharedString,
|
||||
upgrade_command: String,
|
||||
},
|
||||
Exited(i32),
|
||||
Other(SharedString),
|
||||
}
|
||||
|
@ -484,13 +492,7 @@ pub enum LoadError {
|
|||
impl Display for LoadError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
LoadError::Unsupported { current_version } => {
|
||||
write!(
|
||||
f,
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
)
|
||||
}
|
||||
LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
|
||||
LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
|
||||
LoadError::Other(msg) => write!(f, "{}", msg),
|
||||
}
|
||||
|
@ -500,75 +502,38 @@ impl Display for LoadError {
|
|||
impl Error for LoadError {}
|
||||
|
||||
impl AcpThread {
|
||||
pub async fn spawn(
|
||||
server: impl AgentServer + 'static,
|
||||
root_dir: &Path,
|
||||
pub fn new(
|
||||
connection: impl AgentConnection + 'static,
|
||||
title: SharedString,
|
||||
child_status: Option<Task<Result<()>>>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Self>> {
|
||||
let command = match server.command(&project, cx).await {
|
||||
Ok(command) => command,
|
||||
Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))),
|
||||
};
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
title,
|
||||
project,
|
||||
send_task: None,
|
||||
connection: Arc::new(connection),
|
||||
child_status,
|
||||
}
|
||||
}
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
cx.new(|cx| {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(version) = server.version(&command).await.log_err()
|
||||
&& !version.supported
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
current_version: version.current_version
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
title: "ACP Thread".into(),
|
||||
project,
|
||||
send_task: None,
|
||||
connection: Arc::new(connection),
|
||||
child_status: Some(child_status),
|
||||
_io_task: io_task,
|
||||
}
|
||||
})
|
||||
/// Send a request to the agent and wait for a response.
|
||||
pub fn request<R: AgentRequest + 'static>(
|
||||
&self,
|
||||
params: R,
|
||||
) -> impl use<R> + Future<Output = Result<R::Response>> {
|
||||
let params = params.into_any();
|
||||
let result = self.connection.request_any(params);
|
||||
async move {
|
||||
let result = result.await?;
|
||||
Ok(R::response_from_any(result)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
|
@ -579,45 +544,6 @@ impl AcpThread {
|
|||
&self.project
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn fake(
|
||||
stdin: async_pipe::PipeWriter,
|
||||
stdout: async_pipe::PipeReader,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| {
|
||||
foreground_executor.spawn(fut).detach();
|
||||
},
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn({
|
||||
async move {
|
||||
io_fut.await.log_err();
|
||||
}
|
||||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
title: "ACP Thread".into(),
|
||||
project,
|
||||
send_task: None,
|
||||
connection: Arc::new(connection),
|
||||
child_status: None,
|
||||
_io_task: io_task,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn title(&self) -> SharedString {
|
||||
self.title.clone()
|
||||
}
|
||||
|
@ -711,7 +637,7 @@ impl AcpThread {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn request_tool_call(
|
||||
pub fn request_new_tool_call(
|
||||
&mut self,
|
||||
tool_call: acp::RequestToolCallConfirmationParams,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -731,6 +657,30 @@ impl AcpThread {
|
|||
ToolCallRequest { id, outcome: rx }
|
||||
}
|
||||
|
||||
pub fn request_tool_call_confirmation(
|
||||
&mut self,
|
||||
tool_call_id: ToolCallId,
|
||||
confirmation: acp::ToolCallConfirmation,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<ToolCallRequest> {
|
||||
let project = self.project.read(cx).languages().clone();
|
||||
let Some((_, call)) = self.tool_call_mut(tool_call_id) else {
|
||||
anyhow::bail!("Tool call not found");
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
call.status = ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
|
||||
respond_tx: tx,
|
||||
};
|
||||
|
||||
Ok(ToolCallRequest {
|
||||
id: tool_call_id,
|
||||
outcome: rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn push_tool_call(
|
||||
&mut self,
|
||||
request: acp::PushToolCallParams,
|
||||
|
@ -912,19 +862,17 @@ impl AcpThread {
|
|||
false
|
||||
}
|
||||
|
||||
pub fn initialize(
|
||||
&self,
|
||||
) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
|
||||
let connection = self.connection.clone();
|
||||
async move { connection.initialize().await }
|
||||
pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
|
||||
self.request(acp::InitializeParams {
|
||||
protocol_version: ProtocolVersion::latest(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
|
||||
let connection = self.connection.clone();
|
||||
async move { connection.request(acp::AuthenticateParams).await }
|
||||
pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.request(acp::AuthenticateParams)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn send_raw(
|
||||
&mut self,
|
||||
message: &str,
|
||||
|
@ -945,7 +893,6 @@ impl AcpThread {
|
|||
message: acp::SendUserMessageParams,
|
||||
cx: &mut Context<Self>,
|
||||
) -> BoxFuture<'static, Result<(), acp::Error>> {
|
||||
let agent = self.connection.clone();
|
||||
self.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage::from_acp(
|
||||
&message,
|
||||
|
@ -959,11 +906,16 @@ impl AcpThread {
|
|||
let cancel = self.cancel(cx);
|
||||
|
||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||
cancel.await.log_err();
|
||||
async {
|
||||
cancel.await.log_err();
|
||||
|
||||
let result = agent.request(message).await;
|
||||
tx.send(result).log_err();
|
||||
this.update(cx, |this, _cx| this.send_task.take()).log_err();
|
||||
let result = this.update(cx, |this, _| this.request(message))?.await;
|
||||
tx.send(result).log_err();
|
||||
this.update(cx, |this, _cx| this.send_task.take())?;
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await
|
||||
.log_err();
|
||||
}));
|
||||
|
||||
async move {
|
||||
|
@ -976,12 +928,10 @@ impl AcpThread {
|
|||
}
|
||||
|
||||
pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
|
||||
let agent = self.connection.clone();
|
||||
|
||||
if self.send_task.take().is_some() {
|
||||
let request = self.request(acp::CancelSendMessageParams);
|
||||
cx.spawn(async move |this, cx| {
|
||||
agent.request(acp::CancelSendMessageParams).await?;
|
||||
|
||||
request.await?;
|
||||
this.update(cx, |this, _cx| {
|
||||
for entry in this.entries.iter_mut() {
|
||||
if let AgentThreadEntry::ToolCall(call) = entry {
|
||||
|
@ -1019,6 +969,7 @@ impl AcpThread {
|
|||
pub fn read_text_file(
|
||||
&self,
|
||||
request: acp::ReadTextFileParams,
|
||||
reuse_shared_snapshot: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
let project = self.project.clone();
|
||||
|
@ -1032,28 +983,60 @@ impl AcpThread {
|
|||
});
|
||||
let buffer = load??.await?;
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
project.update(cx, |project, cx| {
|
||||
let position = buffer
|
||||
.read(cx)
|
||||
.snapshot()
|
||||
.anchor_before(Point::new(request.line.unwrap_or_default(), 0));
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
|
||||
let snapshot = if reuse_shared_snapshot {
|
||||
this.read_with(cx, |this, _| {
|
||||
this.shared_buffers.get(&buffer.clone()).cloned()
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let snapshot = if let Some(snapshot) = snapshot {
|
||||
snapshot
|
||||
} else {
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
project.update(cx, |project, cx| {
|
||||
let position = buffer
|
||||
.read(cx)
|
||||
.snapshot()
|
||||
.anchor_before(Point::new(request.line.unwrap_or_default(), 0));
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
|
||||
buffer.update(cx, |buffer, _| buffer.snapshot())?
|
||||
};
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
let text = snapshot.text();
|
||||
this.shared_buffers.insert(buffer.clone(), snapshot);
|
||||
text
|
||||
})
|
||||
if request.line.is_none() && request.limit.is_none() {
|
||||
return Ok(text);
|
||||
}
|
||||
let limit = request.limit.unwrap_or(u32::MAX) as usize;
|
||||
let Some(line) = request.line else {
|
||||
return Ok(text.lines().take(limit).collect::<String>());
|
||||
};
|
||||
|
||||
let count = text.lines().count();
|
||||
if count < line as usize {
|
||||
anyhow::bail!("There are only {} lines", count);
|
||||
}
|
||||
Ok(text
|
||||
.lines()
|
||||
.skip(line as usize + 1)
|
||||
.take(limit)
|
||||
.collect::<String>())
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1134,16 +1117,49 @@ impl AcpThread {
|
|||
}
|
||||
}
|
||||
|
||||
struct AcpClientDelegate {
|
||||
#[derive(Clone)]
|
||||
pub struct AcpClientDelegate {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
cx: AsyncApp,
|
||||
// sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
|
||||
}
|
||||
|
||||
impl AcpClientDelegate {
|
||||
fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
|
||||
pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
|
||||
Self { thread, cx }
|
||||
}
|
||||
|
||||
pub async fn request_existing_tool_call_confirmation(
|
||||
&self,
|
||||
tool_call_id: ToolCallId,
|
||||
confirmation: acp::ToolCallConfirmation,
|
||||
) -> Result<ToolCallConfirmationOutcome> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let ToolCallRequest { outcome, .. } = cx
|
||||
.update(|cx| {
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")??;
|
||||
|
||||
Ok(outcome.await?)
|
||||
}
|
||||
|
||||
pub async fn read_text_file_reusing_snapshot(
|
||||
&self,
|
||||
request: acp::ReadTextFileParams,
|
||||
) -> Result<acp::ReadTextFileResponse, acp::Error> {
|
||||
let content = self
|
||||
.cx
|
||||
.update(|cx| {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.read_text_file(request, true, cx))
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
.await?;
|
||||
Ok(acp::ReadTextFileResponse { content })
|
||||
}
|
||||
}
|
||||
|
||||
impl acp::Client for AcpClientDelegate {
|
||||
|
@ -1172,7 +1188,7 @@ impl acp::Client for AcpClientDelegate {
|
|||
let ToolCallRequest { id, outcome } = cx
|
||||
.update(|cx| {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.request_tool_call(request, cx))
|
||||
.update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
|
||||
})?
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
|
@ -1218,7 +1234,7 @@ impl acp::Client for AcpClientDelegate {
|
|||
.cx
|
||||
.update(|cx| {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.read_text_file(request, cx))
|
||||
.update(cx, |thread, cx| thread.read_text_file(request, false, cx))
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
.await?;
|
||||
|
@ -1260,7 +1276,7 @@ pub struct ToolCallRequest {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use agent_servers::{AgentServerCommand, AgentServerVersion};
|
||||
use anyhow::anyhow;
|
||||
use async_pipe::{PipeReader, PipeWriter};
|
||||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||
use gpui::{AsyncApp, TestAppContext};
|
||||
|
@ -1269,7 +1285,7 @@ mod tests {
|
|||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::{future::BoxedLocal, stream::StreamExt as _};
|
||||
use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration};
|
||||
use std::{cell::RefCell, rc::Rc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
@ -1515,265 +1531,6 @@ mod tests {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_basic(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries.len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
thread.entries[1],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
std::fs::write(
|
||||
tempdir.path().join("foo.rs"),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
println!(\"Hello, world!\");
|
||||
}
|
||||
"},
|
||||
)
|
||||
.expect("failed to write file");
|
||||
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
|
||||
let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
acp::SendUserMessageParams {
|
||||
chunks: vec![
|
||||
acp::UserMessageChunk::Text {
|
||||
text: "Read the file ".into(),
|
||||
},
|
||||
acp::UserMessageChunk::Path {
|
||||
path: Path::new("foo.rs").into(),
|
||||
},
|
||||
acp::UserMessageChunk::Text {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 3);
|
||||
assert!(matches!(
|
||||
thread.entries[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_)));
|
||||
let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else {
|
||||
panic!("Expected AssistantMessage")
|
||||
};
|
||||
assert!(
|
||||
assistant_message.to_markdown(cx).contains("Hello, world!"),
|
||||
"unexpected assistant message: {:?}",
|
||||
assistant_message.to_markdown(cx)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
thread.entries[3],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
content: Some(ToolCallContent::Markdown { markdown }),
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
markdown.read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello, world!"),
|
||||
r#"Expected '{}' to contain "Hello, world!""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_cancel(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.cancel(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Canceled,
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
&thread.entries().last().unwrap(),
|
||||
AgentThreadEntry::AssistantMessage(..),
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_until_first_tool_call(
|
||||
thread: &Entity<AcpThread>,
|
||||
cx: &mut TestAppContext,
|
||||
|
@ -1801,66 +1558,39 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn gemini_acp_thread(
|
||||
project: Entity<Project>,
|
||||
current_dir: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Entity<AcpThread> {
|
||||
struct DevGemini;
|
||||
|
||||
impl agent_servers::AgentServer for DevGemini {
|
||||
async fn command(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<agent_servers::AgentServerCommand> {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../gemini-cli/packages/cli")
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path, "--experimental-acp".into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(
|
||||
&self,
|
||||
_command: &agent_servers::AgentServerCommand,
|
||||
) -> Result<AgentServerVersion> {
|
||||
Ok(AgentServerVersion {
|
||||
current_version: "0.1.0".into(),
|
||||
supported: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _| thread.initialize())
|
||||
.await
|
||||
.unwrap();
|
||||
thread
|
||||
}
|
||||
|
||||
pub fn fake_acp_thread(
|
||||
project: Entity<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
|
||||
let (stdin_tx, stdin_rx) = async_pipe::pipe();
|
||||
let (stdout_tx, stdout_rx) = async_pipe::pipe();
|
||||
let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx)));
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
|
||||
stdin_tx,
|
||||
stdout_rx,
|
||||
move |fut| {
|
||||
foreground_executor.spawn(fut).detach();
|
||||
},
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn({
|
||||
async move {
|
||||
io_fut.await.log_err();
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
|
||||
});
|
||||
let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
|
||||
(thread, agent)
|
||||
}
|
||||
|
||||
pub struct FakeAcpServer {
|
||||
connection: acp::ClientConnection,
|
||||
|
||||
_io_task: Task<()>,
|
||||
on_user_message: Option<
|
||||
Rc<
|
20
crates/acp_thread/src/connection.rs
Normal file
20
crates/acp_thread/src/connection.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
use agentic_coding_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use futures::future::{FutureExt as _, LocalBoxFuture};
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn request_any(
|
||||
&self,
|
||||
params: acp::AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>>;
|
||||
}
|
||||
|
||||
impl AgentConnection for acp::AgentConnection {
|
||||
fn request_any(
|
||||
&self,
|
||||
params: acp::AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
|
||||
let task = self.request_any(params);
|
||||
async move { Ok(task.await?) }.boxed_local()
|
||||
}
|
||||
}
|
|
@ -5,6 +5,10 @@ edition.workspace = true
|
|||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[features]
|
||||
test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support"]
|
||||
gemini = []
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
|
@ -13,15 +17,32 @@ path = "src/agent_servers.rs"
|
|||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
language.workspace = true
|
||||
indoc.workspace = true
|
||||
acp_thread = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -1,30 +1,24 @@
|
|||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
mod claude;
|
||||
mod gemini;
|
||||
mod settings;
|
||||
mod stdio_agent_server;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
pub use claude::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
pub use stdio_agent_server::*;
|
||||
|
||||
use acp_thread::AcpThread;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AsyncApp, Entity, SharedString};
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources, SettingsStore};
|
||||
use util::{ResultExt, paths};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
AllAgentServersSettings::register(cx);
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
pub struct AllAgentServersSettings {
|
||||
gemini: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
pub struct AgentServerSettings {
|
||||
#[serde(flatten)]
|
||||
command: AgentServerCommand,
|
||||
settings::init(cx);
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
|
||||
|
@ -36,153 +30,28 @@ pub struct AgentServerCommand {
|
|||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
pub struct Gemini;
|
||||
|
||||
pub struct AgentServerVersion {
|
||||
pub current_version: SharedString,
|
||||
pub supported: bool,
|
||||
pub enum AgentServerVersion {
|
||||
Supported,
|
||||
Unsupported {
|
||||
error_message: SharedString,
|
||||
upgrade_message: SharedString,
|
||||
upgrade_command: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait AgentServer: Send {
|
||||
fn command(
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> &'static str;
|
||||
fn empty_state_headline(&self) -> &'static str;
|
||||
fn empty_state_message(&self) -> &'static str;
|
||||
fn supports_always_allow(&self) -> bool;
|
||||
|
||||
fn new_thread(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> impl Future<Output = Result<AgentServerCommand>>;
|
||||
|
||||
fn version(
|
||||
&self,
|
||||
command: &AgentServerCommand,
|
||||
) -> impl Future<Output = Result<AgentServerVersion>> + Send;
|
||||
}
|
||||
|
||||
const GEMINI_ACP_ARG: &str = "--experimental-acp";
|
||||
|
||||
impl AgentServer for Gemini {
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let custom_command = cx.read_global(|settings: &SettingsStore, _| {
|
||||
let settings = settings.get::<AllAgentServersSettings>(None);
|
||||
settings
|
||||
.gemini
|
||||
.as_ref()
|
||||
.map(|gemini_settings| AgentServerCommand {
|
||||
path: gemini_settings.command.path.clone(),
|
||||
args: gemini_settings
|
||||
.command
|
||||
.args
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(std::iter::once(GEMINI_ACP_ARG.into()))
|
||||
.collect(),
|
||||
env: gemini_settings.command.env.clone(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if let Some(custom_command) = custom_command {
|
||||
return Ok(custom_command);
|
||||
}
|
||||
|
||||
if let Some(path) = find_bin_in_path("gemini", project, cx).await {
|
||||
return Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![GEMINI_ACP_ARG.into()],
|
||||
env: None,
|
||||
});
|
||||
}
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
})?;
|
||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
||||
|
||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
||||
fs.create_dir(&directory).await?;
|
||||
node_runtime
|
||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
||||
.await?;
|
||||
let path = directory.join("node_modules/.bin/gemini");
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![GEMINI_ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?.into();
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(GEMINI_ACP_ARG);
|
||||
|
||||
Ok(AgentServerVersion {
|
||||
current_version,
|
||||
supported,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_bin_in_path(
|
||||
bin_name: &'static str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<PathBuf> {
|
||||
let (env_task, root_dir) = project
|
||||
.update(cx, |project, cx| {
|
||||
let worktree = project.visible_worktrees(cx).next();
|
||||
match worktree {
|
||||
Some(worktree) => {
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_worktree_environment(worktree.clone(), cx)
|
||||
});
|
||||
|
||||
let path = worktree.read(cx).abs_path();
|
||||
(env_task, path)
|
||||
}
|
||||
None => {
|
||||
let path: Arc<Path> = paths::home_dir().as_path().into();
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_directory_environment(path.clone(), cx)
|
||||
});
|
||||
(env_task, path)
|
||||
}
|
||||
}
|
||||
})
|
||||
.log_err()?;
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let which_result = if cfg!(windows) {
|
||||
which::which(bin_name)
|
||||
} else {
|
||||
let env = env_task.await.unwrap_or_default();
|
||||
let shell_path = env.get("PATH").cloned();
|
||||
which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
|
||||
};
|
||||
|
||||
if let Err(which::Error::CannotFindBinaryPath) = which_result {
|
||||
return None;
|
||||
}
|
||||
|
||||
which_result.log_err()
|
||||
})
|
||||
.await
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AgentServerCommand {
|
||||
|
@ -209,23 +78,3 @@ impl std::fmt::Debug for AgentServerCommand {
|
|||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl settings::Settings for AllAgentServersSettings {
|
||||
const KEY: Option<&'static str> = Some("agent_servers");
|
||||
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for value in sources.defaults_and_customizations() {
|
||||
if value.gemini.is_some() {
|
||||
settings.gemini = value.gemini.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
|
680
crates/agent_servers/src/claude.rs
Normal file
680
crates/agent_servers/src/claude.rs
Normal file
|
@ -0,0 +1,680 @@
|
|||
mod mcp_server;
|
||||
mod tools;
|
||||
|
||||
use collections::HashMap;
|
||||
use project::Project;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use agentic_coding_protocol::{
|
||||
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
|
||||
StreamAssistantMessageChunkParams, ToolCallContent, UpdateToolCallParams,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt};
|
||||
use futures::{
|
||||
AsyncRead, AsyncWrite, FutureExt, StreamExt,
|
||||
channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
io::BufReader,
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::claude::mcp_server::ClaudeMcpServer;
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
use crate::{AgentServer, find_bin_in_path};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ClaudeCode;
|
||||
|
||||
impl AgentServer for ClaudeCode {
|
||||
fn name(&self) -> &'static str {
|
||||
"Claude Code"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
self.name()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiClaude
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let project = project.clone();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let title = self.name().into();
|
||||
cx.spawn(async move |cx| {
|
||||
let (mut delegate_tx, delegate_rx) = watch::channel(None);
|
||||
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let permission_mcp_server =
|
||||
ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
|
||||
|
||||
let mut mcp_servers = HashMap::default();
|
||||
mcp_servers.insert(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
permission_mcp_server.server_config()?,
|
||||
);
|
||||
let mcp_config = McpConfig { mcp_servers };
|
||||
|
||||
let mcp_config_file = tempfile::NamedTempFile::new()?;
|
||||
let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
|
||||
|
||||
let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
|
||||
mcp_config_file
|
||||
.write_all(serde_json::to_string(&mcp_config)?.as_bytes())
|
||||
.await?;
|
||||
mcp_config_file.flush().await?;
|
||||
|
||||
let command = find_bin_in_path("claude", &project, cx)
|
||||
.await
|
||||
.context("Failed to find claude binary")?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command)
|
||||
.args([
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--print",
|
||||
"--verbose",
|
||||
"--mcp-config",
|
||||
mcp_config_path.to_string_lossy().as_ref(),
|
||||
"--permission-prompt-tool",
|
||||
&format!(
|
||||
"mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::PERMISSION_TOOL
|
||||
),
|
||||
"--allowedTools",
|
||||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
])
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
||||
|
||||
let io_task =
|
||||
ClaudeAgentConnection::handle_io(outgoing_rx, incoming_message_tx, stdin, stdout);
|
||||
cx.background_spawn(async move {
|
||||
io_task.await.log_err();
|
||||
drop(mcp_config_path);
|
||||
drop(child);
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.new(|cx| {
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
|
||||
delegate_tx.send(Some(delegate.clone())).log_err();
|
||||
|
||||
let handler_task = cx.foreground_executor().spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let tool_id_map = tool_id_map.clone();
|
||||
async move {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentConnection::handle_message(
|
||||
delegate.clone(),
|
||||
message,
|
||||
end_turn_tx.clone(),
|
||||
tool_id_map.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut connection = ClaudeAgentConnection {
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: None,
|
||||
};
|
||||
|
||||
connection._mcp_server = Some(permission_mcp_server);
|
||||
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
/// Send a request to the agent and wait for a response.
|
||||
fn request_any(
|
||||
&self,
|
||||
params: AnyAgentRequest,
|
||||
) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
|
||||
let end_turn_tx = self.end_turn_tx.clone();
|
||||
let outgoing_tx = self.outgoing_tx.clone();
|
||||
async move {
|
||||
match params {
|
||||
// todo: consider sending an empty request so we get the init response?
|
||||
AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
|
||||
acp::InitializeResponse {
|
||||
is_authenticated: true,
|
||||
protocol_version: ProtocolVersion::latest(),
|
||||
},
|
||||
)),
|
||||
AnyAgentRequest::AuthenticateParams(_) => {
|
||||
Err(anyhow!("Authentication not supported"))
|
||||
}
|
||||
AnyAgentRequest::SendUserMessageParams(message) => {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
end_turn_tx.borrow_mut().replace(tx);
|
||||
let mut content = String::new();
|
||||
for chunk in message.chunks {
|
||||
match chunk {
|
||||
agentic_coding_protocol::UserMessageChunk::Text { text } => {
|
||||
content.push_str(&text)
|
||||
}
|
||||
agentic_coding_protocol::UserMessageChunk::Path { path } => {
|
||||
content.push_str(&format!("@{path:?}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
outgoing_tx.unbounded_send(SdkMessage::User {
|
||||
message: Message {
|
||||
role: Role::User,
|
||||
content: Content::UntaggedText(content),
|
||||
id: None,
|
||||
model: None,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: None,
|
||||
},
|
||||
session_id: None,
|
||||
})?;
|
||||
rx.await??;
|
||||
Ok(AnyAgentResult::SendUserMessageResponse(
|
||||
acp::SendUserMessageResponse,
|
||||
))
|
||||
}
|
||||
AnyAgentRequest::CancelSendMessageParams(_) => Ok(
|
||||
AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
|
||||
),
|
||||
}
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
struct ClaudeAgentConnection {
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
_mcp_server: Option<ClaudeMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
impl ClaudeAgentConnection {
|
||||
async fn handle_message(
|
||||
delegate: AcpClientDelegate,
|
||||
message: SdkMessage,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
) {
|
||||
match message {
|
||||
SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => {
|
||||
for chunk in message.content.chunks() {
|
||||
match chunk {
|
||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||
delegate
|
||||
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Text { text },
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolUse { id, name, input } => {
|
||||
if let Some(resp) = delegate
|
||||
.push_tool_call(ClaudeTool::infer(&name, input).as_acp())
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
tool_id_map.borrow_mut().insert(id, resp.id);
|
||||
}
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let id = tool_id_map.borrow_mut().remove(&tool_use_id);
|
||||
if let Some(id) = id {
|
||||
delegate
|
||||
.update_tool_call(UpdateToolCallParams {
|
||||
tool_call_id: id,
|
||||
status: acp::ToolCallStatus::Finished,
|
||||
content: Some(ToolCallContent::Markdown {
|
||||
// For now we only include text content
|
||||
markdown: content.to_string(),
|
||||
}),
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
delegate
|
||||
.stream_assistant_message_chunk(StreamAssistantMessageChunkParams {
|
||||
chunk: acp::AssistantMessageChunk::Text {
|
||||
text: format!("Unsupported content: {:?}", chunk),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::Result {
|
||||
is_error, subtype, ..
|
||||
} => {
|
||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
||||
if is_error {
|
||||
end_turn_tx.send(Err(anyhow!("Error: {subtype}"))).ok();
|
||||
} else {
|
||||
end_turn_tx.send(Ok(())).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_io(
|
||||
mut outgoing_rx: UnboundedReceiver<SdkMessage>,
|
||||
incoming_tx: UnboundedSender<SdkMessage>,
|
||||
mut outgoing_bytes: impl Unpin + AsyncWrite,
|
||||
incoming_bytes: impl Unpin + AsyncRead,
|
||||
) -> Result<()> {
|
||||
let mut output_reader = BufReader::new(incoming_bytes);
|
||||
let mut outgoing_line = Vec::new();
|
||||
let mut incoming_line = String::new();
|
||||
loop {
|
||||
select_biased! {
|
||||
message = outgoing_rx.next() => {
|
||||
if let Some(message) = message {
|
||||
outgoing_line.clear();
|
||||
serde_json::to_writer(&mut outgoing_line, &message)?;
|
||||
log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
|
||||
outgoing_line.push(b'\n');
|
||||
outgoing_bytes.write_all(&outgoing_line).await.ok();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
|
||||
if bytes_read? == 0 {
|
||||
break
|
||||
}
|
||||
log::trace!("recv: {}", &incoming_line);
|
||||
match serde_json::from_str::<SdkMessage>(&incoming_line) {
|
||||
Ok(message) => {
|
||||
incoming_tx.unbounded_send(message).log_err();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
|
||||
}
|
||||
}
|
||||
incoming_line.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: Role,
|
||||
content: Content,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
stop_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
stop_sequence: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Content {
|
||||
UntaggedText(String),
|
||||
Chunks(Vec<ContentChunk>),
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
|
||||
match self {
|
||||
Self::Chunks(chunks) => chunks.into_iter(),
|
||||
Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Content {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Content::UntaggedText(txt) => write!(f, "{}", txt),
|
||||
Content::Chunks(chunks) => {
|
||||
for chunk in chunks {
|
||||
write!(f, "{}", chunk)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ContentChunk {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
ToolResult {
|
||||
content: Content,
|
||||
tool_use_id: String,
|
||||
},
|
||||
// TODO
|
||||
Image,
|
||||
Document,
|
||||
Thinking,
|
||||
RedactedThinking,
|
||||
WebSearchToolResult,
|
||||
#[serde(untagged)]
|
||||
UntaggedText(String),
|
||||
}
|
||||
|
||||
impl Display for ContentChunk {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentChunk::Text { text } => write!(f, "{}", text),
|
||||
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
|
||||
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::ToolUse { .. }
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
write!(f, "\n{:?}\n", &self)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Usage {
|
||||
input_tokens: u32,
|
||||
cache_creation_input_tokens: u32,
|
||||
cache_read_input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
service_tier: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum Role {
|
||||
System,
|
||||
Assistant,
|
||||
User,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct MessageParam {
|
||||
role: Role,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum SdkMessage {
|
||||
// An assistant message
|
||||
Assistant {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
|
||||
// A user message
|
||||
User {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
},
|
||||
|
||||
// Emitted as the last message in a conversation
|
||||
Result {
|
||||
subtype: ResultErrorType,
|
||||
duration_ms: f64,
|
||||
duration_api_ms: f64,
|
||||
is_error: bool,
|
||||
num_turns: i32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<String>,
|
||||
session_id: String,
|
||||
total_cost_usd: f64,
|
||||
},
|
||||
// Emitted as the first message at the start of a conversation
|
||||
System {
|
||||
cwd: String,
|
||||
session_id: String,
|
||||
tools: Vec<String>,
|
||||
model: String,
|
||||
mcp_servers: Vec<McpServer>,
|
||||
#[serde(rename = "apiKeySource")]
|
||||
api_key_source: String,
|
||||
#[serde(rename = "permissionMode")]
|
||||
permission_mode: PermissionMode,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ResultErrorType {
|
||||
Success,
|
||||
ErrorMaxTurns,
|
||||
ErrorDuringExecution,
|
||||
}
|
||||
|
||||
impl Display for ResultErrorType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ResultErrorType::Success => write!(f, "success"),
|
||||
ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
|
||||
ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct McpServer {
|
||||
name: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
enum PermissionMode {
|
||||
Default,
|
||||
AcceptEdits,
|
||||
BypassPermissions,
|
||||
Plan,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpConfig {
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct McpServerConfig {
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_content_untagged_text() {
|
||||
let json = json!("Hello, world!");
|
||||
let content: Content = serde_json::from_value(json).unwrap();
|
||||
match content {
|
||||
Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
|
||||
_ => panic!("Expected UntaggedText variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_content_chunks() {
|
||||
let json = json!([
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "tool_123",
|
||||
"name": "calculator",
|
||||
"input": {"operation": "add", "a": 1, "b": 2}
|
||||
}
|
||||
]);
|
||||
let content: Content = serde_json::from_value(json).unwrap();
|
||||
match content {
|
||||
Content::Chunks(chunks) => {
|
||||
assert_eq!(chunks.len(), 2);
|
||||
match &chunks[0] {
|
||||
ContentChunk::Text { text } => assert_eq!(text, "Hello"),
|
||||
_ => panic!("Expected Text chunk"),
|
||||
}
|
||||
match &chunks[1] {
|
||||
ContentChunk::ToolUse { id, name, input } => {
|
||||
assert_eq!(id, "tool_123");
|
||||
assert_eq!(name, "calculator");
|
||||
assert_eq!(input["operation"], "add");
|
||||
assert_eq!(input["a"], 1);
|
||||
assert_eq!(input["b"], 2);
|
||||
}
|
||||
_ => panic!("Expected ToolUse chunk"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Chunks variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_tool_result_untagged_text() {
|
||||
let json = json!({
|
||||
"type": "tool_result",
|
||||
"content": "Result content",
|
||||
"tool_use_id": "tool_456"
|
||||
});
|
||||
let chunk: ContentChunk = serde_json::from_value(json).unwrap();
|
||||
match chunk {
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
match content {
|
||||
Content::UntaggedText(text) => assert_eq!(text, "Result content"),
|
||||
_ => panic!("Expected UntaggedText content"),
|
||||
}
|
||||
assert_eq!(tool_use_id, "tool_456");
|
||||
}
|
||||
_ => panic!("Expected ToolResult variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_tool_result_chunks() {
|
||||
let json = json!({
|
||||
"type": "tool_result",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Processing complete"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Result: 42"
|
||||
}
|
||||
],
|
||||
"tool_use_id": "tool_789"
|
||||
});
|
||||
let chunk: ContentChunk = serde_json::from_value(json).unwrap();
|
||||
match chunk {
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
match content {
|
||||
Content::Chunks(chunks) => {
|
||||
assert_eq!(chunks.len(), 2);
|
||||
match &chunks[0] {
|
||||
ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
|
||||
_ => panic!("Expected Text chunk"),
|
||||
}
|
||||
match &chunks[1] {
|
||||
ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
|
||||
_ => panic!("Expected Text chunk"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Chunks content"),
|
||||
}
|
||||
assert_eq!(tool_use_id, "tool_789");
|
||||
}
|
||||
_ => panic!("Expected ToolResult variant"),
|
||||
}
|
||||
}
|
||||
}
|
303
crates/agent_servers/src/claude/mcp_server.rs
Normal file
303
crates/agent_servers/src/claude/mcp_server.rs
Normal file
|
@ -0,0 +1,303 @@
|
|||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use acp_thread::AcpClientDelegate;
|
||||
use agentic_coding_protocol::{self as acp, Client, ReadTextFileParams, WriteTextFileParams};
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::{
|
||||
listener::McpServer,
|
||||
types::{
|
||||
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
|
||||
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
|
||||
ToolResponseContent, ToolsCapabilities, requests,
|
||||
},
|
||||
};
|
||||
use gpui::{App, AsyncApp, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::debug_panic;
|
||||
|
||||
use crate::claude::{
|
||||
McpServerConfig,
|
||||
tools::{ClaudeTool, EditToolParams, EditToolResponse, ReadToolParams, ReadToolResponse},
|
||||
};
|
||||
|
||||
pub struct ClaudeMcpServer {
|
||||
server: McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
pub const READ_TOOL: &str = "Read";
|
||||
pub const EDIT_TOOL: &str = "Edit";
|
||||
pub const PERMISSION_TOOL: &str = "Confirmation";
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
struct PermissionToolParams {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl ClaudeMcpServer {
|
||||
pub async fn new(
|
||||
delegate: watch::Receiver<Option<AcpClientDelegate>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
|
||||
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
|
||||
Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx)
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<McpServerConfig> {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
let zed_path = util::get_shell_safe_zed_path()?;
|
||||
#[cfg(target_os = "windows")]
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("finding current executable path for use in mcp_server")?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
Ok(McpServerConfig {
|
||||
command: zed_path,
|
||||
args: vec![
|
||||
"--nc".into(),
|
||||
self.server.socket_path().display().to_string(),
|
||||
],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(InitializeResponse {
|
||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
||||
capabilities: ServerCapabilities {
|
||||
experimental: None,
|
||||
logging: None,
|
||||
completions: None,
|
||||
prompts: None,
|
||||
resources: None,
|
||||
tools: Some(ToolsCapabilities {
|
||||
list_changed: Some(false),
|
||||
}),
|
||||
},
|
||||
server_info: Implementation {
|
||||
name: SERVER_NAME.into(),
|
||||
version: "0.1.0".into(),
|
||||
},
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(ListToolsResponse {
|
||||
tools: vec![
|
||||
Tool {
|
||||
name: PERMISSION_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(PermissionToolParams).into(),
|
||||
description: None,
|
||||
annotations: None,
|
||||
},
|
||||
Tool {
|
||||
name: READ_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(ReadToolParams).into(),
|
||||
description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Read file".to_string()),
|
||||
read_only_hint: Some(true),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
// if time passes the contents might change, but it's not going to do anything different
|
||||
// true or false seem too strong, let's try a none.
|
||||
idempotent_hint: None,
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
name: EDIT_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(EditToolParams).into(),
|
||||
description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Edit file".to_string()),
|
||||
read_only_hint: Some(false),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}),
|
||||
},
|
||||
],
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_call_tool(
|
||||
request: CallToolParams,
|
||||
mut delegate_watch: watch::Receiver<Option<AcpClientDelegate>>,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &App,
|
||||
) -> Task<Result<CallToolResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let Some(delegate) = delegate_watch.recv().await? else {
|
||||
debug_panic!("Sent None delegate");
|
||||
anyhow::bail!("Server not available");
|
||||
};
|
||||
|
||||
if request.name.as_str() == PERMISSION_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result =
|
||||
Self::handle_permissions_tool_call(input, delegate, tool_id_map, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == READ_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result = Self::handle_read_tool_call(input, delegate, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == EDIT_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result = Self::handle_edit_tool_call(input, delegate, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else {
|
||||
anyhow::bail!("Unsupported tool");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_read_tool_call(
|
||||
params: ReadToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<ReadToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = delegate
|
||||
.read_text_file(ReadTextFileParams {
|
||||
path: params.abs_path,
|
||||
line: params.offset,
|
||||
limit: params.limit,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(ReadToolResponse {
|
||||
content: response.content,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_edit_tool_call(
|
||||
params: EditToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<EditToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = delegate
|
||||
.read_text_file_reusing_snapshot(ReadTextFileParams {
|
||||
path: params.abs_path.clone(),
|
||||
line: None,
|
||||
limit: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let new_content = response.content.replace(¶ms.old_text, ¶ms.new_text);
|
||||
if new_content == response.content {
|
||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
||||
}
|
||||
|
||||
delegate
|
||||
.write_text_file(WriteTextFileParams {
|
||||
path: params.abs_path,
|
||||
content: new_content,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(EditToolResponse)
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_permissions_tool_call(
|
||||
params: PermissionToolParams,
|
||||
delegate: AcpClientDelegate,
|
||||
tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<PermissionToolResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone());
|
||||
|
||||
let tool_call_id = match params.tool_use_id {
|
||||
Some(tool_use_id) => tool_id_map
|
||||
.borrow()
|
||||
.get(&tool_use_id)
|
||||
.cloned()
|
||||
.context("Tool call ID not found")?,
|
||||
|
||||
None => delegate.push_tool_call(claude_tool.as_acp()).await?.id,
|
||||
};
|
||||
|
||||
let outcome = delegate
|
||||
.request_existing_tool_call_confirmation(
|
||||
tool_call_id,
|
||||
claude_tool.confirmation(None),
|
||||
)
|
||||
.await?;
|
||||
|
||||
match outcome {
|
||||
acp::ToolCallConfirmationOutcome::Allow
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllow
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
|
||||
| acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: params.input,
|
||||
}),
|
||||
acp::ToolCallConfirmationOutcome::Reject
|
||||
| acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Deny,
|
||||
updated_input: params.input,
|
||||
}),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
670
crates/agent_servers/src/claude/tools.rs
Normal file
670
crates/agent_servers/src/claude/tools.rs
Normal file
|
@ -0,0 +1,670 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use agentic_coding_protocol::{self as acp, PushToolCallParams, ToolCallLocation};
|
||||
use itertools::Itertools;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
pub enum ClaudeTool {
|
||||
Task(Option<TaskToolParams>),
|
||||
NotebookRead(Option<NotebookReadToolParams>),
|
||||
NotebookEdit(Option<NotebookEditToolParams>),
|
||||
Edit(Option<EditToolParams>),
|
||||
MultiEdit(Option<MultiEditToolParams>),
|
||||
ReadFile(Option<ReadToolParams>),
|
||||
Write(Option<WriteToolParams>),
|
||||
Ls(Option<LsToolParams>),
|
||||
Glob(Option<GlobToolParams>),
|
||||
Grep(Option<GrepToolParams>),
|
||||
Terminal(Option<BashToolParams>),
|
||||
WebFetch(Option<WebFetchToolParams>),
|
||||
WebSearch(Option<WebSearchToolParams>),
|
||||
TodoWrite(Option<TodoWriteToolParams>),
|
||||
ExitPlanMode(Option<ExitPlanModeToolParams>),
|
||||
Other {
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
impl ClaudeTool {
|
||||
pub fn infer(tool_name: &str, input: serde_json::Value) -> Self {
|
||||
match tool_name {
|
||||
// Known tools
|
||||
"mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()),
|
||||
"mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()),
|
||||
"MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()),
|
||||
"Write" => Self::Write(serde_json::from_value(input).log_err()),
|
||||
"LS" => Self::Ls(serde_json::from_value(input).log_err()),
|
||||
"Glob" => Self::Glob(serde_json::from_value(input).log_err()),
|
||||
"Grep" => Self::Grep(serde_json::from_value(input).log_err()),
|
||||
"Bash" => Self::Terminal(serde_json::from_value(input).log_err()),
|
||||
"WebFetch" => Self::WebFetch(serde_json::from_value(input).log_err()),
|
||||
"WebSearch" => Self::WebSearch(serde_json::from_value(input).log_err()),
|
||||
"TodoWrite" => Self::TodoWrite(serde_json::from_value(input).log_err()),
|
||||
"exit_plan_mode" => Self::ExitPlanMode(serde_json::from_value(input).log_err()),
|
||||
"Task" => Self::Task(serde_json::from_value(input).log_err()),
|
||||
"NotebookRead" => Self::NotebookRead(serde_json::from_value(input).log_err()),
|
||||
"NotebookEdit" => Self::NotebookEdit(serde_json::from_value(input).log_err()),
|
||||
// Inferred from name
|
||||
_ => {
|
||||
let tool_name = tool_name.to_lowercase();
|
||||
|
||||
if tool_name.contains("edit") || tool_name.contains("write") {
|
||||
Self::Edit(None)
|
||||
} else if tool_name.contains("terminal") {
|
||||
Self::Terminal(None)
|
||||
} else {
|
||||
Self::Other {
|
||||
name: tool_name.to_string(),
|
||||
input,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn label(&self) -> String {
|
||||
match &self {
|
||||
Self::Task(Some(params)) => params.description.clone(),
|
||||
Self::Task(None) => "Task".into(),
|
||||
Self::NotebookRead(Some(params)) => {
|
||||
format!("Read Notebook {}", params.notebook_path.display())
|
||||
}
|
||||
Self::NotebookRead(None) => "Read Notebook".into(),
|
||||
Self::NotebookEdit(Some(params)) => {
|
||||
format!("Edit Notebook {}", params.notebook_path.display())
|
||||
}
|
||||
Self::NotebookEdit(None) => "Edit Notebook".into(),
|
||||
Self::Terminal(Some(params)) => format!("`{}`", params.command),
|
||||
Self::Terminal(None) => "Terminal".into(),
|
||||
Self::ReadFile(_) => "Read File".into(),
|
||||
Self::Ls(Some(params)) => {
|
||||
format!("List Directory {}", params.path.display())
|
||||
}
|
||||
Self::Ls(None) => "List Directory".into(),
|
||||
Self::Edit(Some(params)) => {
|
||||
format!("Edit {}", params.abs_path.display())
|
||||
}
|
||||
Self::Edit(None) => "Edit".into(),
|
||||
Self::MultiEdit(Some(params)) => {
|
||||
format!("Multi Edit {}", params.file_path.display())
|
||||
}
|
||||
Self::MultiEdit(None) => "Multi Edit".into(),
|
||||
Self::Write(Some(params)) => {
|
||||
format!("Write {}", params.file_path.display())
|
||||
}
|
||||
Self::Write(None) => "Write".into(),
|
||||
Self::Glob(Some(params)) => {
|
||||
format!("Glob {params}")
|
||||
}
|
||||
Self::Glob(None) => "Glob".into(),
|
||||
Self::Grep(Some(params)) => params.to_string(),
|
||||
Self::Grep(None) => "Grep".into(),
|
||||
Self::WebFetch(Some(params)) => format!("Fetch {}", params.url),
|
||||
Self::WebFetch(None) => "Fetch".into(),
|
||||
Self::WebSearch(Some(params)) => format!("Web Search: {}", params),
|
||||
Self::WebSearch(None) => "Web Search".into(),
|
||||
Self::TodoWrite(Some(params)) => format!(
|
||||
"Update TODOs: {}",
|
||||
params.todos.iter().map(|todo| &todo.content).join(", ")
|
||||
),
|
||||
Self::TodoWrite(None) => "Update TODOs".into(),
|
||||
Self::ExitPlanMode(_) => "Exit Plan Mode".into(),
|
||||
Self::Other { name, .. } => name.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn content(&self) -> Option<acp::ToolCallContent> {
|
||||
match &self {
|
||||
ClaudeTool::Other { input, .. } => Some(acp::ToolCallContent::Markdown {
|
||||
markdown: format!(
|
||||
"```json\n{}```",
|
||||
serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
|
||||
),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn icon(&self) -> acp::Icon {
|
||||
match self {
|
||||
Self::Task(_) => acp::Icon::Hammer,
|
||||
Self::NotebookRead(_) => acp::Icon::FileSearch,
|
||||
Self::NotebookEdit(_) => acp::Icon::Pencil,
|
||||
Self::Edit(_) => acp::Icon::Pencil,
|
||||
Self::MultiEdit(_) => acp::Icon::Pencil,
|
||||
Self::Write(_) => acp::Icon::Pencil,
|
||||
Self::ReadFile(_) => acp::Icon::FileSearch,
|
||||
Self::Ls(_) => acp::Icon::Folder,
|
||||
Self::Glob(_) => acp::Icon::FileSearch,
|
||||
Self::Grep(_) => acp::Icon::Regex,
|
||||
Self::Terminal(_) => acp::Icon::Terminal,
|
||||
Self::WebSearch(_) => acp::Icon::Globe,
|
||||
Self::WebFetch(_) => acp::Icon::Globe,
|
||||
Self::TodoWrite(_) => acp::Icon::LightBulb,
|
||||
Self::ExitPlanMode(_) => acp::Icon::Hammer,
|
||||
Self::Other { .. } => acp::Icon::Hammer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn confirmation(&self, description: Option<String>) -> acp::ToolCallConfirmation {
|
||||
match &self {
|
||||
Self::Edit(_) | Self::Write(_) | Self::NotebookEdit(_) | Self::MultiEdit(_) => {
|
||||
acp::ToolCallConfirmation::Edit { description }
|
||||
}
|
||||
Self::WebFetch(params) => acp::ToolCallConfirmation::Fetch {
|
||||
urls: params
|
||||
.as_ref()
|
||||
.map(|p| vec![p.url.clone()])
|
||||
.unwrap_or_default(),
|
||||
description,
|
||||
},
|
||||
Self::Terminal(Some(BashToolParams {
|
||||
description,
|
||||
command,
|
||||
..
|
||||
})) => acp::ToolCallConfirmation::Execute {
|
||||
command: command.clone(),
|
||||
root_command: command.clone(),
|
||||
description: description.clone(),
|
||||
},
|
||||
Self::ExitPlanMode(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {}", params.plan)
|
||||
} else {
|
||||
params.plan.clone()
|
||||
},
|
||||
},
|
||||
Self::Task(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {}", params.description)
|
||||
} else {
|
||||
params.description.clone()
|
||||
},
|
||||
},
|
||||
Self::Ls(Some(LsToolParams { path, .. }))
|
||||
| Self::ReadFile(Some(ReadToolParams { abs_path: path, .. })) => {
|
||||
let path = path.display();
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {path}")
|
||||
} else {
|
||||
path.to_string()
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
|
||||
let path = notebook_path.display();
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {path}")
|
||||
} else {
|
||||
path.to_string()
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::Glob(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::Grep(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::WebSearch(Some(params)) => acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params.to_string()
|
||||
},
|
||||
},
|
||||
Self::TodoWrite(Some(params)) => {
|
||||
let params = params.todos.iter().map(|todo| &todo.content).join(", ");
|
||||
acp::ToolCallConfirmation::Other {
|
||||
description: if let Some(description) = description {
|
||||
format!("{description} {params}")
|
||||
} else {
|
||||
params
|
||||
},
|
||||
}
|
||||
}
|
||||
Self::Terminal(None)
|
||||
| Self::Task(None)
|
||||
| Self::NotebookRead(None)
|
||||
| Self::ExitPlanMode(None)
|
||||
| Self::Ls(None)
|
||||
| Self::Glob(None)
|
||||
| Self::Grep(None)
|
||||
| Self::ReadFile(None)
|
||||
| Self::WebSearch(None)
|
||||
| Self::TodoWrite(None)
|
||||
| Self::Other { .. } => acp::ToolCallConfirmation::Other {
|
||||
description: description.unwrap_or("".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn locations(&self) -> Vec<acp::ToolCallLocation> {
|
||||
match &self {
|
||||
Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![ToolCallLocation {
|
||||
path: abs_path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => {
|
||||
vec![ToolCallLocation {
|
||||
path: file_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::Write(Some(WriteToolParams { file_path, .. })) => vec![ToolCallLocation {
|
||||
path: file_path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::ReadFile(Some(ReadToolParams {
|
||||
abs_path, offset, ..
|
||||
})) => vec![ToolCallLocation {
|
||||
path: abs_path.clone(),
|
||||
line: *offset,
|
||||
}],
|
||||
Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => {
|
||||
vec![ToolCallLocation {
|
||||
path: notebook_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => {
|
||||
vec![ToolCallLocation {
|
||||
path: notebook_path.clone(),
|
||||
line: None,
|
||||
}]
|
||||
}
|
||||
Self::Glob(Some(GlobToolParams {
|
||||
path: Some(path), ..
|
||||
})) => vec![ToolCallLocation {
|
||||
path: path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::Ls(Some(LsToolParams { path, .. })) => vec![ToolCallLocation {
|
||||
path: path.clone(),
|
||||
line: None,
|
||||
}],
|
||||
Self::Grep(Some(GrepToolParams {
|
||||
path: Some(path), ..
|
||||
})) => vec![ToolCallLocation {
|
||||
path: PathBuf::from(path),
|
||||
line: None,
|
||||
}],
|
||||
Self::Task(_)
|
||||
| Self::NotebookRead(None)
|
||||
| Self::NotebookEdit(None)
|
||||
| Self::Edit(None)
|
||||
| Self::MultiEdit(None)
|
||||
| Self::Write(None)
|
||||
| Self::ReadFile(None)
|
||||
| Self::Ls(None)
|
||||
| Self::Glob(_)
|
||||
| Self::Grep(_)
|
||||
| Self::Terminal(_)
|
||||
| Self::WebFetch(_)
|
||||
| Self::WebSearch(_)
|
||||
| Self::TodoWrite(_)
|
||||
| Self::ExitPlanMode(_)
|
||||
| Self::Other { .. } => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_acp(&self) -> PushToolCallParams {
|
||||
PushToolCallParams {
|
||||
label: self.label(),
|
||||
content: self.content(),
|
||||
icon: self.icon(),
|
||||
locations: self.locations(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct EditToolParams {
|
||||
/// The absolute path to the file to read.
|
||||
pub abs_path: PathBuf,
|
||||
/// The old text to replace (must be unique in the file)
|
||||
pub old_text: String,
|
||||
/// The new text.
|
||||
pub new_text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct EditToolResponse;
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct ReadToolParams {
|
||||
/// The absolute path to the file to read.
|
||||
pub abs_path: PathBuf,
|
||||
/// Which line to start reading from. Omit to start from the beginning.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub offset: Option<u32>,
|
||||
/// How many lines to read. Omit for the whole file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ReadToolResponse {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct WriteToolParams {
|
||||
/// Absolute path for new file
|
||||
pub file_path: PathBuf,
|
||||
/// File content
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct BashToolParams {
|
||||
/// Shell command to execute
|
||||
pub command: String,
|
||||
/// 5-10 word description of what command does
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Timeout in ms (max 600000ms/10min, default 120000ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct GlobToolParams {
|
||||
/// Glob pattern like **/*.js or src/**/*.ts
|
||||
pub pattern: String,
|
||||
/// Directory to search in (omit for current directory)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for GlobToolParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if let Some(path) = &self.path {
|
||||
write!(f, "{}", path.display())?;
|
||||
}
|
||||
write!(f, "{}", self.pattern)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct LsToolParams {
|
||||
/// Absolute path to directory
|
||||
pub path: PathBuf,
|
||||
/// Array of glob patterns to ignore
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub ignore: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct GrepToolParams {
|
||||
/// Regex pattern to search for
|
||||
pub pattern: String,
|
||||
/// File/directory to search (defaults to current directory)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<String>,
|
||||
/// "content" (shows lines), "files_with_matches" (default), "count"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_mode: Option<GrepOutputMode>,
|
||||
/// Filter files with glob pattern like "*.js"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub glob: Option<String>,
|
||||
/// File type filter like "js", "py", "rust"
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
pub file_type: Option<String>,
|
||||
/// Case insensitive search
|
||||
#[serde(rename = "-i", default, skip_serializing_if = "is_false")]
|
||||
pub case_insensitive: bool,
|
||||
/// Show line numbers (content mode only)
|
||||
#[serde(rename = "-n", default, skip_serializing_if = "is_false")]
|
||||
pub line_numbers: bool,
|
||||
/// Lines after match (content mode only)
|
||||
#[serde(rename = "-A", skip_serializing_if = "Option::is_none")]
|
||||
pub after_context: Option<u32>,
|
||||
/// Lines before match (content mode only)
|
||||
#[serde(rename = "-B", skip_serializing_if = "Option::is_none")]
|
||||
pub before_context: Option<u32>,
|
||||
/// Lines before and after match (content mode only)
|
||||
#[serde(rename = "-C", skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<u32>,
|
||||
/// Enable multiline/cross-line matching
|
||||
#[serde(default, skip_serializing_if = "is_false")]
|
||||
pub multiline: bool,
|
||||
/// Limit output to first N results
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub head_limit: Option<u32>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for GrepToolParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "grep")?;
|
||||
|
||||
// Boolean flags
|
||||
if self.case_insensitive {
|
||||
write!(f, " -i")?;
|
||||
}
|
||||
if self.line_numbers {
|
||||
write!(f, " -n")?;
|
||||
}
|
||||
|
||||
// Context options
|
||||
if let Some(after) = self.after_context {
|
||||
write!(f, " -A {}", after)?;
|
||||
}
|
||||
if let Some(before) = self.before_context {
|
||||
write!(f, " -B {}", before)?;
|
||||
}
|
||||
if let Some(context) = self.context {
|
||||
write!(f, " -C {}", context)?;
|
||||
}
|
||||
|
||||
// Output mode
|
||||
if let Some(mode) = &self.output_mode {
|
||||
match mode {
|
||||
GrepOutputMode::FilesWithMatches => write!(f, " -l")?,
|
||||
GrepOutputMode::Count => write!(f, " -c")?,
|
||||
GrepOutputMode::Content => {} // Default mode
|
||||
}
|
||||
}
|
||||
|
||||
// Head limit
|
||||
if let Some(limit) = self.head_limit {
|
||||
write!(f, " | head -{}", limit)?;
|
||||
}
|
||||
|
||||
// Glob pattern
|
||||
if let Some(glob) = &self.glob {
|
||||
write!(f, " --include=\"{}\"", glob)?;
|
||||
}
|
||||
|
||||
// File type
|
||||
if let Some(file_type) = &self.file_type {
|
||||
write!(f, " --type={}", file_type)?;
|
||||
}
|
||||
|
||||
// Multiline
|
||||
if self.multiline {
|
||||
write!(f, " -P")?; // Perl-compatible regex for multiline
|
||||
}
|
||||
|
||||
// Pattern (escaped if contains special characters)
|
||||
write!(f, " \"{}\"", self.pattern)?;
|
||||
|
||||
// Path
|
||||
if let Some(path) = &self.path {
|
||||
write!(f, " {}", path)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TodoPriority {
|
||||
High,
|
||||
Medium,
|
||||
Low,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TodoStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
pub struct Todo {
|
||||
/// Unique identifier
|
||||
pub id: String,
|
||||
/// Task description
|
||||
pub content: String,
|
||||
/// Priority level of the todo
|
||||
pub priority: TodoPriority,
|
||||
/// Current status of the todo
|
||||
pub status: TodoStatus,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct TodoWriteToolParams {
|
||||
pub todos: Vec<Todo>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct ExitPlanModeToolParams {
|
||||
/// Implementation plan in markdown format
|
||||
pub plan: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct TaskToolParams {
|
||||
/// Short 3-5 word description of task
|
||||
pub description: String,
|
||||
/// Detailed task for agent to perform
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct NotebookReadToolParams {
|
||||
/// Absolute path to .ipynb file
|
||||
pub notebook_path: PathBuf,
|
||||
/// Specific cell ID to read
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cell_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CellType {
|
||||
Code,
|
||||
Markdown,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EditMode {
|
||||
Replace,
|
||||
Insert,
|
||||
Delete,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct NotebookEditToolParams {
|
||||
/// Absolute path to .ipynb file
|
||||
pub notebook_path: PathBuf,
|
||||
/// New cell content
|
||||
pub new_source: String,
|
||||
/// Cell ID to edit
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cell_id: Option<String>,
|
||||
/// Type of cell (code or markdown)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cell_type: Option<CellType>,
|
||||
/// Edit operation mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub edit_mode: Option<EditMode>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
pub struct MultiEditItem {
|
||||
/// The text to search for and replace
|
||||
pub old_string: String,
|
||||
/// The replacement text
|
||||
pub new_string: String,
|
||||
/// Whether to replace all occurrences or just the first
|
||||
#[serde(default, skip_serializing_if = "is_false")]
|
||||
pub replace_all: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct MultiEditToolParams {
|
||||
/// Absolute path to file
|
||||
pub file_path: PathBuf,
|
||||
/// List of edits to apply
|
||||
pub edits: Vec<MultiEditItem>,
|
||||
}
|
||||
|
||||
fn is_false(v: &bool) -> bool {
|
||||
!*v
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GrepOutputMode {
|
||||
Content,
|
||||
FilesWithMatches,
|
||||
Count,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct WebFetchToolParams {
|
||||
/// Valid URL to fetch
|
||||
#[serde(rename = "url")]
|
||||
pub url: String,
|
||||
/// What to extract from content
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct WebSearchToolParams {
|
||||
/// Search query (min 2 chars)
|
||||
pub query: String,
|
||||
/// Only include these domains
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub allowed_domains: Vec<String>,
|
||||
/// Exclude these domains
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub blocked_domains: Vec<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebSearchToolParams {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "\"{}\"", self.query)?;
|
||||
|
||||
if !self.allowed_domains.is_empty() {
|
||||
write!(f, " (allowed: {})", self.allowed_domains.join(", "))?;
|
||||
}
|
||||
|
||||
if !self.blocked_domains.is_empty() {
|
||||
write!(f, " (blocked: {})", self.blocked_domains.join(", "))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
501
crates/agent_servers/src/gemini.rs
Normal file
501
crates/agent_servers/src/gemini.rs
Normal file
|
@ -0,0 +1,501 @@
|
|||
use crate::stdio_agent_server::{StdioAgentServer, find_bin_in_path};
|
||||
use crate::{AgentServerCommand, AgentServerVersion};
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
|
||||
use crate::AllAgentServersSettings;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Gemini;
|
||||
|
||||
const ACP_ARG: &str = "--experimental-acp";
|
||||
|
||||
impl StdioAgentServer for Gemini {
|
||||
fn name(&self) -> &'static str {
|
||||
"Gemini"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Welcome to Gemini"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"Ask questions, edit files, run commands.\nBe specific for the best results."
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiGemini
|
||||
}
|
||||
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let custom_command = cx.read_global(|settings: &SettingsStore, _| {
|
||||
let settings = settings.get::<AllAgentServersSettings>(None);
|
||||
settings
|
||||
.gemini
|
||||
.as_ref()
|
||||
.map(|gemini_settings| AgentServerCommand {
|
||||
path: gemini_settings.command.path.clone(),
|
||||
args: gemini_settings
|
||||
.command
|
||||
.args
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(std::iter::once(ACP_ARG.into()))
|
||||
.collect(),
|
||||
env: gemini_settings.command.env.clone(),
|
||||
})
|
||||
})?;
|
||||
|
||||
if let Some(custom_command) = custom_command {
|
||||
return Ok(custom_command);
|
||||
}
|
||||
|
||||
if let Some(path) = find_bin_in_path("gemini", project, cx).await {
|
||||
return Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
});
|
||||
}
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
})?;
|
||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
||||
|
||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
||||
fs.create_dir(&directory).await?;
|
||||
node_runtime
|
||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
||||
.await?;
|
||||
let path = directory.join("node_modules/.bin/gemini");
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
||||
|
||||
if supported {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
} else {
|
||||
Ok(AgentServerVersion::Unsupported {
|
||||
error_message: format!(
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
).into(),
|
||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{path::Path, time::Duration};
|
||||
|
||||
use acp_thread::{
|
||||
AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
|
||||
ToolCallStatus,
|
||||
};
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{AsyncApp, Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion, StdioAgentServer};
|
||||
|
||||
pub async fn gemini_acp_thread(
|
||||
project: Entity<Project>,
|
||||
current_dir: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Entity<AcpThread> {
|
||||
#[derive(Clone)]
|
||||
struct DevGemini;
|
||||
|
||||
impl StdioAgentServer for DevGemini {
|
||||
async fn command(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../gemini-cli/packages/cli")
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path, "--experimental-acp".into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, _command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiGemini
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
let thread = cx
|
||||
.update(|cx| AgentServer::new_thread(&DevGemini, current_dir.as_ref(), &project, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _| thread.initialize())
|
||||
.await
|
||||
.unwrap();
|
||||
thread
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_basic(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(
|
||||
thread.entries()[1],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
std::fs::write(
|
||||
tempdir.path().join("foo.rs"),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
println!(\"Hello, world!\");
|
||||
}
|
||||
"},
|
||||
)
|
||||
.expect("failed to write file");
|
||||
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
|
||||
let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
acp::SendUserMessageParams {
|
||||
chunks: vec![
|
||||
acp::UserMessageChunk::Text {
|
||||
text: "Read the file ".into(),
|
||||
},
|
||||
acp::UserMessageChunk::Path {
|
||||
path: Path::new("foo.rs").into(),
|
||||
},
|
||||
acp::UserMessageChunk::Text {
|
||||
text: " and tell me what the content of the println! is".into(),
|
||||
},
|
||||
],
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries().len(), 3);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
));
|
||||
assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
|
||||
let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
|
||||
panic!("Expected AssistantMessage")
|
||||
};
|
||||
assert!(
|
||||
assistant_message.to_markdown(cx).contains("Hello, world!"),
|
||||
"unexpected assistant message: {:?}",
|
||||
assistant_message.to_markdown(cx)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
|
||||
assert!(matches!(
|
||||
thread.entries()[3],
|
||||
AgentThreadEntry::AssistantMessage(_)
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
|
||||
|
||||
assert!(matches!(
|
||||
&thread.entries()[2],
|
||||
AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
});
|
||||
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
content: Some(ToolCallContent::Markdown { markdown }),
|
||||
status: ToolCallStatus::Allowed { .. },
|
||||
..
|
||||
}) = &thread.entries()[2]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
|
||||
markdown.read_with(cx, |md, _cx| {
|
||||
assert!(
|
||||
md.source().contains("Hello, world!"),
|
||||
r#"Expected '{}' to contain "Hello, world!""#,
|
||||
md.source()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "gemini"), ignore)]
|
||||
async fn test_gemini_cancel(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
|
||||
});
|
||||
|
||||
let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
status:
|
||||
ToolCallStatus::WaitingForConfirmation {
|
||||
confirmation: ToolCallConfirmation::Execute { root_command, .. },
|
||||
..
|
||||
},
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
assert_eq!(root_command, "echo");
|
||||
|
||||
*id
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| thread.cancel(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Canceled,
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
else {
|
||||
panic!();
|
||||
};
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
&thread.entries().last().unwrap(),
|
||||
AgentThreadEntry::AssistantMessage(..),
|
||||
))
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_until_first_tool_call(
|
||||
thread: &Entity<AcpThread>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> usize {
|
||||
let (mut tx, mut rx) = mpsc::channel::<usize>(1);
|
||||
|
||||
let subscription = cx.update(|cx| {
|
||||
cx.subscribe(thread, move |thread, _, cx| {
|
||||
for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
|
||||
if matches!(entry, AgentThreadEntry::ToolCall(_)) {
|
||||
return tx.try_send(ix).unwrap();
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
select! {
|
||||
_ = cx.executor().timer(Duration::from_secs(10)).fuse() => {
|
||||
panic!("Timeout waiting for tool call")
|
||||
}
|
||||
ix = rx.next().fuse() => {
|
||||
drop(subscription);
|
||||
ix.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
41
crates/agent_servers/src/settings.rs
Normal file
41
crates/agent_servers/src/settings.rs
Normal file
|
@ -0,0 +1,41 @@
|
|||
use crate::AgentServerCommand;
|
||||
use anyhow::Result;
|
||||
use gpui::App;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
AllAgentServersSettings::register(cx);
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
pub struct AgentServerSettings {
|
||||
#[serde(flatten)]
|
||||
pub command: AgentServerCommand,
|
||||
}
|
||||
|
||||
impl settings::Settings for AllAgentServersSettings {
|
||||
const KEY: Option<&'static str> = Some("agent_servers");
|
||||
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for value in sources.defaults_and_customizations() {
|
||||
if value.gemini.is_some() {
|
||||
settings.gemini = value.gemini.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
169
crates/agent_servers/src/stdio_agent_server.rs
Normal file
169
crates/agent_servers/src/stdio_agent_server.rs
Normal file
|
@ -0,0 +1,169 @@
|
|||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
|
||||
use agentic_coding_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
|
||||
use project::Project;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{ResultExt, paths};
|
||||
|
||||
pub trait StdioAgentServer: Send + Clone {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> &'static str;
|
||||
fn empty_state_headline(&self) -> &'static str;
|
||||
fn empty_state_message(&self) -> &'static str;
|
||||
fn supports_always_allow(&self) -> bool;
|
||||
|
||||
fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> impl Future<Output = Result<AgentServerCommand>>;
|
||||
|
||||
fn version(
|
||||
&self,
|
||||
command: &AgentServerCommand,
|
||||
) -> impl Future<Output = Result<AgentServerVersion>> + Send;
|
||||
}
|
||||
|
||||
impl<T: StdioAgentServer + 'static> AgentServer for T {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name()
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
self.empty_state_headline()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
self.empty_state_message()
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
self.logo()
|
||||
}
|
||||
|
||||
fn supports_always_allow(&self) -> bool {
|
||||
self.supports_always_allow()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
&self,
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let title = self.name().into();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let command = this.command(&project, cx).await?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
cx.new(|cx| {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
|
||||
AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn find_bin_in_path(
|
||||
bin_name: &'static str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Option<PathBuf> {
|
||||
let (env_task, root_dir) = project
|
||||
.update(cx, |project, cx| {
|
||||
let worktree = project.visible_worktrees(cx).next();
|
||||
match worktree {
|
||||
Some(worktree) => {
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_worktree_environment(worktree.clone(), cx)
|
||||
});
|
||||
|
||||
let path = worktree.read(cx).abs_path();
|
||||
(env_task, path)
|
||||
}
|
||||
None => {
|
||||
let path: Arc<Path> = paths::home_dir().as_path().into();
|
||||
let env_task = project.environment().update(cx, |env, cx| {
|
||||
env.get_directory_environment(path.clone(), cx)
|
||||
});
|
||||
(env_task, path)
|
||||
}
|
||||
}
|
||||
})
|
||||
.log_err()?;
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let which_result = if cfg!(windows) {
|
||||
which::which(bin_name)
|
||||
} else {
|
||||
let env = env_task.await.unwrap_or_default();
|
||||
let shell_path = env.get("PATH").cloned();
|
||||
which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
|
||||
};
|
||||
|
||||
if let Err(which::Error::CannotFindBinaryPath) = which_result {
|
||||
return None;
|
||||
}
|
||||
|
||||
which_result.log_err()
|
||||
})
|
||||
.await
|
||||
}
|
|
@ -16,7 +16,7 @@ doctest = false
|
|||
test-support = ["gpui/test-support", "language/test-support"]
|
||||
|
||||
[dependencies]
|
||||
acp.workspace = true
|
||||
acp_thread.workspace = true
|
||||
agent.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use agent_servers::AgentServer;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
|
@ -35,7 +36,7 @@ use util::ResultExt;
|
|||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||
|
||||
use ::acp::{
|
||||
use ::acp_thread::{
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
|
||||
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallConfirmation, ToolCallContent,
|
||||
ToolCallId, ToolCallStatus,
|
||||
|
@ -49,6 +50,7 @@ use crate::{AgentDiffPane, Follow, KeepAll, OpenAgentDiff, RejectAll};
|
|||
const RESPONSE_PADDING_X: Pixels = px(19.);
|
||||
|
||||
pub struct AcpThreadView {
|
||||
agent: Rc<dyn AgentServer>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_state: ThreadState,
|
||||
|
@ -80,8 +82,15 @@ enum ThreadState {
|
|||
},
|
||||
}
|
||||
|
||||
struct AlwaysAllowOption {
|
||||
id: &'static str,
|
||||
label: SharedString,
|
||||
outcome: acp::ToolCallConfirmationOutcome,
|
||||
}
|
||||
|
||||
impl AcpThreadView {
|
||||
pub fn new(
|
||||
agent: Rc<dyn AgentServer>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
message_history: Rc<RefCell<MessageHistory<acp::SendUserMessageParams>>>,
|
||||
|
@ -158,9 +167,10 @@ impl AcpThreadView {
|
|||
);
|
||||
|
||||
Self {
|
||||
agent: agent.clone(),
|
||||
workspace: workspace.clone(),
|
||||
project: project.clone(),
|
||||
thread_state: Self::initial_state(workspace, project, window, cx),
|
||||
thread_state: Self::initial_state(agent, workspace, project, window, cx),
|
||||
message_editor,
|
||||
message_set_from_history: false,
|
||||
_message_editor_subscription: message_editor_subscription,
|
||||
|
@ -177,6 +187,7 @@ impl AcpThreadView {
|
|||
}
|
||||
|
||||
fn initial_state(
|
||||
agent: Rc<dyn AgentServer>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
|
@ -189,9 +200,9 @@ impl AcpThreadView {
|
|||
.map(|worktree| worktree.read(cx).abs_path())
|
||||
.unwrap_or_else(|| paths::home_dir().as_path().into());
|
||||
|
||||
let task = agent.new_thread(&root_dir, &project, cx);
|
||||
let load_task = cx.spawn_in(window, async move |this, cx| {
|
||||
let thread = match AcpThread::spawn(agent_servers::Gemini, &root_dir, project, cx).await
|
||||
{
|
||||
let thread = match task.await {
|
||||
Ok(thread) => thread,
|
||||
Err(err) => {
|
||||
this.update(cx, |this, cx| {
|
||||
|
@ -410,6 +421,33 @@ impl AcpThreadView {
|
|||
);
|
||||
}
|
||||
|
||||
fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some(thread) = self.thread() {
|
||||
AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err();
|
||||
}
|
||||
}
|
||||
|
||||
fn open_edited_buffer(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(diff) =
|
||||
AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
diff.update(cx, |diff, cx| {
|
||||
diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn set_draft_message(
|
||||
message_editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
|
@ -485,33 +523,6 @@ impl AcpThreadView {
|
|||
true
|
||||
}
|
||||
|
||||
fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some(thread) = self.thread() {
|
||||
AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err();
|
||||
}
|
||||
}
|
||||
|
||||
fn open_edited_buffer(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(diff) =
|
||||
AgentDiffPane::deploy(thread.clone(), self.workspace.clone(), window, cx).log_err()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
diff.update(cx, |diff, cx| {
|
||||
diff.move_to_path(PathKey::for_buffer(&buffer, cx), window, cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_thread_event(
|
||||
&mut self,
|
||||
thread: &Entity<AcpThread>,
|
||||
|
@ -608,6 +619,7 @@ impl AcpThreadView {
|
|||
let authenticate = thread.read(cx).authenticate();
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
async move |this, cx| {
|
||||
let result = authenticate.await;
|
||||
|
||||
|
@ -617,8 +629,13 @@ impl AcpThreadView {
|
|||
Markdown::new(format!("Error: {err}").into(), None, None, cx)
|
||||
}))
|
||||
} else {
|
||||
this.thread_state =
|
||||
Self::initial_state(this.workspace.clone(), project.clone(), window, cx)
|
||||
this.thread_state = Self::initial_state(
|
||||
agent,
|
||||
this.workspace.clone(),
|
||||
project.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
this.auth_task.take()
|
||||
})
|
||||
|
@ -1047,14 +1064,6 @@ impl AcpThreadView {
|
|||
) -> AnyElement {
|
||||
let confirmation_container = v_flex().mt_1().py_1p5();
|
||||
|
||||
let button_container = h_flex()
|
||||
.pt_1p5()
|
||||
.px_1p5()
|
||||
.gap_1()
|
||||
.justify_end()
|
||||
.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx));
|
||||
|
||||
match confirmation {
|
||||
ToolCallConfirmation::Edit { description } => confirmation_container
|
||||
.child(
|
||||
|
@ -1068,60 +1077,15 @@ impl AcpThreadView {
|
|||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(
|
||||
button_container
|
||||
.child(
|
||||
Button::new(("always_allow", tool_call_id.0), "Always Allow Edits")
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow Edits".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Execute {
|
||||
command,
|
||||
|
@ -1140,66 +1104,15 @@ impl AcpThreadView {
|
|||
}),
|
||||
))
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(
|
||||
button_container
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow", tool_call_id.0),
|
||||
format!("Always Allow {root_command}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: format!("Always Allow {root_command}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Mcp {
|
||||
server_name,
|
||||
|
@ -1220,87 +1133,22 @@ impl AcpThreadView {
|
|||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(
|
||||
button_container
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow_server", tool_call_id.0),
|
||||
format!("Always Allow {server_name}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(
|
||||
("always_allow_tool", tool_call_id.0),
|
||||
format!("Always Allow {tool_display_name}"),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[
|
||||
AlwaysAllowOption {
|
||||
id: "always_allow_server",
|
||||
label: format!("Always Allow {server_name}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
|
||||
},
|
||||
AlwaysAllowOption {
|
||||
id: "always_allow_tool",
|
||||
label: format!("Always Allow {tool_display_name}").into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllowTool,
|
||||
},
|
||||
],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Fetch { description, urls } => confirmation_container
|
||||
.child(
|
||||
|
@ -1328,63 +1176,15 @@ impl AcpThreadView {
|
|||
})),
|
||||
)
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(
|
||||
button_container
|
||||
.child(
|
||||
Button::new(("always_allow", tool_call_id.0), "Always Allow")
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
ToolCallConfirmation::Other { description } => confirmation_container
|
||||
.child(v_flex().px_2().pb_1p5().child(self.render_markdown(
|
||||
|
@ -1392,67 +1192,87 @@ impl AcpThreadView {
|
|||
default_markdown_style(false, window, cx),
|
||||
)))
|
||||
.children(content.map(|content| self.render_tool_call_content(content, window, cx)))
|
||||
.child(
|
||||
button_container
|
||||
.child(
|
||||
Button::new(("always_allow", tool_call_id.0), "Always Allow")
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
.child(self.render_confirmation_buttons(
|
||||
&[AlwaysAllowOption {
|
||||
id: "always_allow",
|
||||
label: "Always Allow".into(),
|
||||
outcome: acp::ToolCallConfirmationOutcome::AlwaysAllow,
|
||||
}],
|
||||
tool_call_id,
|
||||
cx,
|
||||
))
|
||||
.into_any(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_confirmation_buttons(
|
||||
&self,
|
||||
always_allow_options: &[AlwaysAllowOption],
|
||||
tool_call_id: ToolCallId,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
h_flex()
|
||||
.pt_1p5()
|
||||
.px_1p5()
|
||||
.gap_1()
|
||||
.justify_end()
|
||||
.border_t_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.when(self.agent.supports_always_allow(), |this| {
|
||||
this.children(always_allow_options.into_iter().map(|always_allow_option| {
|
||||
let outcome = always_allow_option.outcome;
|
||||
Button::new(
|
||||
(always_allow_option.id, tool_call_id.0),
|
||||
always_allow_option.label.clone(),
|
||||
)
|
||||
.icon(IconName::CheckDouble)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(id, outcome, cx);
|
||||
}
|
||||
}))
|
||||
}))
|
||||
})
|
||||
.child(
|
||||
Button::new(("allow", tool_call_id.0), "Allow")
|
||||
.icon(IconName::Check)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Success)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Allow,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new(("reject", tool_call_id.0), "Reject")
|
||||
.icon(IconName::X)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Error)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call_id;
|
||||
move |this, _, _, cx| {
|
||||
this.authorize_tool_call(
|
||||
id,
|
||||
acp::ToolCallConfirmationOutcome::Reject,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_diff_editor(&self, multibuffer: &Entity<MultiBuffer>) -> AnyElement {
|
||||
v_flex()
|
||||
.h_full()
|
||||
|
@ -1466,15 +1286,15 @@ impl AcpThreadView {
|
|||
.into_any()
|
||||
}
|
||||
|
||||
fn render_gemini_logo(&self) -> AnyElement {
|
||||
Icon::new(IconName::AiGemini)
|
||||
fn render_agent_logo(&self) -> AnyElement {
|
||||
Icon::new(self.agent.logo())
|
||||
.color(Color::Muted)
|
||||
.size(IconSize::XLarge)
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_error_gemini_logo(&self) -> AnyElement {
|
||||
let logo = Icon::new(IconName::AiGemini)
|
||||
fn render_error_agent_logo(&self) -> AnyElement {
|
||||
let logo = Icon::new(self.agent.logo())
|
||||
.color(Color::Muted)
|
||||
.size(IconSize::XLarge)
|
||||
.into_any_element();
|
||||
|
@ -1493,49 +1313,50 @@ impl AcpThreadView {
|
|||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_empty_state(&self, loading: bool, cx: &App) -> AnyElement {
|
||||
fn render_empty_state(&self, cx: &App) -> AnyElement {
|
||||
let loading = matches!(&self.thread_state, ThreadState::Loading { .. });
|
||||
|
||||
v_flex()
|
||||
.size_full()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(
|
||||
if loading {
|
||||
h_flex()
|
||||
.justify_center()
|
||||
.child(self.render_gemini_logo())
|
||||
.with_animation(
|
||||
"pulsating_icon",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.0)),
|
||||
|icon, delta| icon.opacity(delta),
|
||||
).into_any()
|
||||
} else {
|
||||
self.render_gemini_logo().into_any_element()
|
||||
}
|
||||
)
|
||||
.child(
|
||||
.child(if loading {
|
||||
h_flex()
|
||||
.mt_4()
|
||||
.mb_1()
|
||||
.justify_center()
|
||||
.child(Headline::new(if loading {
|
||||
"Connecting to Gemini…"
|
||||
} else {
|
||||
"Welcome to Gemini"
|
||||
}).size(HeadlineSize::Medium)),
|
||||
)
|
||||
.child(self.render_agent_logo())
|
||||
.with_animation(
|
||||
"pulsating_icon",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.0)),
|
||||
|icon, delta| icon.opacity(delta),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
self.render_agent_logo().into_any_element()
|
||||
})
|
||||
.child(h_flex().mt_4().mb_1().justify_center().child(if loading {
|
||||
div()
|
||||
.child(LoadingLabel::new("").size(LabelSize::Large))
|
||||
.into_any_element()
|
||||
} else {
|
||||
Headline::new(self.agent.empty_state_headline())
|
||||
.size(HeadlineSize::Medium)
|
||||
.into_any_element()
|
||||
}))
|
||||
.child(
|
||||
div()
|
||||
.max_w_1_2()
|
||||
.text_sm()
|
||||
.text_center()
|
||||
.map(|this| if loading {
|
||||
this.invisible()
|
||||
} else {
|
||||
this.text_color(cx.theme().colors().text_muted)
|
||||
.map(|this| {
|
||||
if loading {
|
||||
this.invisible()
|
||||
} else {
|
||||
this.text_color(cx.theme().colors().text_muted)
|
||||
}
|
||||
})
|
||||
.child("Ask questions, edit files, run commands.\nBe specific for the best results.")
|
||||
.child(self.agent.empty_state_message()),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
@ -1544,7 +1365,7 @@ impl AcpThreadView {
|
|||
v_flex()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_gemini_logo())
|
||||
.child(self.render_error_agent_logo())
|
||||
.child(
|
||||
h_flex()
|
||||
.mt_4()
|
||||
|
@ -1559,7 +1380,7 @@ impl AcpThreadView {
|
|||
let mut container = v_flex()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_gemini_logo())
|
||||
.child(self.render_error_agent_logo())
|
||||
.child(
|
||||
v_flex()
|
||||
.mt_4()
|
||||
|
@ -1575,43 +1396,47 @@ impl AcpThreadView {
|
|||
),
|
||||
);
|
||||
|
||||
if matches!(e, LoadError::Unsupported { .. }) {
|
||||
container =
|
||||
container.child(Button::new("upgrade", "Upgrade Gemini to Latest").on_click(
|
||||
cx.listener(|this, _, window, cx| {
|
||||
this.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().read(cx);
|
||||
let cwd = project.first_project_directory(cx);
|
||||
let shell = project.terminal_settings(&cwd, cx).shell.clone();
|
||||
let command =
|
||||
"npm install -g @google/gemini-cli@latest".to_string();
|
||||
let spawn_in_terminal = task::SpawnInTerminal {
|
||||
id: task::TaskId("install".to_string()),
|
||||
full_label: command.clone(),
|
||||
label: command.clone(),
|
||||
command: Some(command.clone()),
|
||||
args: Vec::new(),
|
||||
command_label: command.clone(),
|
||||
cwd,
|
||||
env: Default::default(),
|
||||
use_new_terminal: true,
|
||||
allow_concurrent_runs: true,
|
||||
reveal: Default::default(),
|
||||
reveal_target: Default::default(),
|
||||
hide: Default::default(),
|
||||
shell,
|
||||
show_summary: true,
|
||||
show_command: true,
|
||||
show_rerun: false,
|
||||
};
|
||||
workspace
|
||||
.spawn_in_terminal(spawn_in_terminal, window, cx)
|
||||
.detach();
|
||||
})
|
||||
.ok();
|
||||
}),
|
||||
));
|
||||
if let LoadError::Unsupported {
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
..
|
||||
} = &e
|
||||
{
|
||||
let upgrade_message = upgrade_message.clone();
|
||||
let upgrade_command = upgrade_command.clone();
|
||||
container = container.child(Button::new("upgrade", upgrade_message).on_click(
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().read(cx);
|
||||
let cwd = project.first_project_directory(cx);
|
||||
let shell = project.terminal_settings(&cwd, cx).shell.clone();
|
||||
let spawn_in_terminal = task::SpawnInTerminal {
|
||||
id: task::TaskId("install".to_string()),
|
||||
full_label: upgrade_command.clone(),
|
||||
label: upgrade_command.clone(),
|
||||
command: Some(upgrade_command.clone()),
|
||||
args: Vec::new(),
|
||||
command_label: upgrade_command.clone(),
|
||||
cwd,
|
||||
env: Default::default(),
|
||||
use_new_terminal: true,
|
||||
allow_concurrent_runs: true,
|
||||
reveal: Default::default(),
|
||||
reveal_target: Default::default(),
|
||||
hide: Default::default(),
|
||||
shell,
|
||||
show_summary: true,
|
||||
show_command: true,
|
||||
show_rerun: false,
|
||||
};
|
||||
workspace
|
||||
.spawn_in_terminal(spawn_in_terminal, window, cx)
|
||||
.detach();
|
||||
})
|
||||
.ok();
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
container.into_any()
|
||||
|
@ -2267,20 +2092,23 @@ impl Render for AcpThreadView {
|
|||
.on_action(cx.listener(Self::next_history_message))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { .. } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", "Sign in to Gemini").on_click(
|
||||
cx.listener(|this, _, window, cx| this.authenticate(window, cx)),
|
||||
),
|
||||
)),
|
||||
ThreadState::Loading { .. } => {
|
||||
v_flex().flex_1().child(self.render_empty_state(true, cx))
|
||||
ThreadState::Unauthenticated { .. } => {
|
||||
v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(
|
||||
h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.authenticate(window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
|
@ -2321,7 +2149,7 @@ impl Render for AcpThreadView {
|
|||
})
|
||||
.children(self.render_edits_bar(&thread, window, cx))
|
||||
} else {
|
||||
this.child(self.render_empty_state(false, cx))
|
||||
this.child(self.render_empty_state(cx))
|
||||
}
|
||||
}),
|
||||
})
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll};
|
||||
use acp::{AcpThread, AcpThreadEvent};
|
||||
use acp_thread::{AcpThread, AcpThreadEvent};
|
||||
use agent::{Thread, ThreadEvent, ThreadSummary};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::Result;
|
||||
|
@ -81,7 +81,7 @@ impl AgentDiffThread {
|
|||
match self {
|
||||
AgentDiffThread::Native(thread) => thread.read(cx).is_generating(),
|
||||
AgentDiffThread::AcpThread(thread) => {
|
||||
thread.read(cx).status() == acp::ThreadStatus::Generating
|
||||
thread.read(cx).status() == acp_thread::ThreadStatus::Generating
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,10 +5,11 @@ use std::rc::Rc;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::NewAcpThread;
|
||||
use crate::NewExternalAgentThread;
|
||||
use crate::agent_diff::AgentDiffThread;
|
||||
use crate::language_model_selector::ToggleModelSelector;
|
||||
use crate::{
|
||||
|
@ -114,10 +115,12 @@ pub fn init(cx: &mut App) {
|
|||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &NewAcpThread, window, cx| {
|
||||
.register_action(|workspace, action: &NewExternalAgentThread, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.new_gemini_thread(window, cx));
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.new_external_thread(action.agent, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
|
@ -136,7 +139,7 @@ pub fn init(cx: &mut App) {
|
|||
let thread = thread.read(cx).thread().clone();
|
||||
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
|
||||
}
|
||||
ActiveView::AcpThread { .. }
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
|
@ -200,7 +203,7 @@ enum ActiveView {
|
|||
message_editor: Entity<MessageEditor>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
},
|
||||
AcpThread {
|
||||
ExternalAgentThread {
|
||||
thread_view: Entity<AcpThreadView>,
|
||||
},
|
||||
TextThread {
|
||||
|
@ -222,9 +225,9 @@ enum WhichFontSize {
|
|||
impl ActiveView {
|
||||
pub fn which_font_size_used(&self) -> WhichFontSize {
|
||||
match self {
|
||||
ActiveView::Thread { .. } | ActiveView::AcpThread { .. } | ActiveView::History => {
|
||||
WhichFontSize::AgentFont
|
||||
}
|
||||
ActiveView::Thread { .. }
|
||||
| ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::History => WhichFontSize::AgentFont,
|
||||
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
|
||||
ActiveView::Configuration => WhichFontSize::None,
|
||||
}
|
||||
|
@ -255,7 +258,7 @@ impl ActiveView {
|
|||
thread.scroll_to_bottom(cx);
|
||||
});
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {}
|
||||
ActiveView::ExternalAgentThread { .. } => {}
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
|
@ -674,7 +677,7 @@ impl AgentPanel {
|
|||
.clone()
|
||||
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
|
||||
}
|
||||
ActiveView::AcpThread { .. }
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
|
@ -757,7 +760,7 @@ impl AgentPanel {
|
|||
ActiveView::Thread { thread, .. } => {
|
||||
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
ActiveView::AcpThread { thread_view, .. } => {
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => {
|
||||
thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx));
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
|
@ -767,7 +770,7 @@ impl AgentPanel {
|
|||
fn active_message_editor(&self) -> Option<&Entity<MessageEditor>> {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { message_editor, .. } => Some(message_editor),
|
||||
ActiveView::AcpThread { .. }
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => None,
|
||||
|
@ -889,35 +892,77 @@ impl AgentPanel {
|
|||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn new_gemini_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn new_external_thread(
|
||||
&mut self,
|
||||
agent_choice: Option<crate::ExternalAgent>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let message_history = self.acp_message_history.clone();
|
||||
|
||||
const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent";
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
struct LastUsedExternalAgent {
|
||||
agent: crate::ExternalAgent,
|
||||
}
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let thread_view = cx.new_window_entity(|window, cx| {
|
||||
crate::acp::AcpThreadView::new(
|
||||
workspace.clone(),
|
||||
project,
|
||||
message_history,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
let server: Rc<dyn AgentServer> = match agent_choice {
|
||||
Some(agent) => {
|
||||
cx.background_spawn(async move {
|
||||
if let Some(serialized) =
|
||||
serde_json::to_string(&LastUsedExternalAgent { agent }).log_err()
|
||||
{
|
||||
KEY_VALUE_STORE
|
||||
.write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
agent.server()
|
||||
}
|
||||
None => cx
|
||||
.background_spawn(async move {
|
||||
KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY)
|
||||
})
|
||||
.await
|
||||
.log_err()
|
||||
.flatten()
|
||||
.and_then(|value| {
|
||||
serde_json::from_str::<LastUsedExternalAgent>(&value).log_err()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
.agent
|
||||
.server(),
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let thread_view = cx.new(|cx| {
|
||||
crate::acp::AcpThreadView::new(
|
||||
server,
|
||||
workspace.clone(),
|
||||
project,
|
||||
message_history,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
this.set_active_view(
|
||||
ActiveView::AcpThread {
|
||||
ActiveView::ExternalAgentThread {
|
||||
thread_view: thread_view.clone(),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn deploy_rules_library(
|
||||
|
@ -1084,7 +1129,7 @@ impl AgentPanel {
|
|||
ActiveView::Thread { message_editor, .. } => {
|
||||
message_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
ActiveView::AcpThread { thread_view } => {
|
||||
ActiveView::ExternalAgentThread { thread_view } => {
|
||||
thread_view.focus_handle(cx).focus(window);
|
||||
}
|
||||
ActiveView::TextThread { context_editor, .. } => {
|
||||
|
@ -1211,7 +1256,7 @@ impl AgentPanel {
|
|||
})
|
||||
.log_err();
|
||||
}
|
||||
ActiveView::AcpThread { .. }
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
|
@ -1267,7 +1312,7 @@ impl AgentPanel {
|
|||
)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
ActiveView::AcpThread { thread_view } => {
|
||||
ActiveView::ExternalAgentThread { thread_view } => {
|
||||
thread_view
|
||||
.update(cx, |thread_view, cx| {
|
||||
thread_view.open_thread_as_markdown(workspace, window, cx)
|
||||
|
@ -1428,7 +1473,7 @@ impl AgentPanel {
|
|||
}
|
||||
})
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {}
|
||||
ActiveView::ExternalAgentThread { .. } => {}
|
||||
ActiveView::History | ActiveView::Configuration => {}
|
||||
}
|
||||
|
||||
|
@ -1517,7 +1562,7 @@ impl Focusable for AgentPanel {
|
|||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
|
||||
ActiveView::AcpThread { thread_view, .. } => thread_view.focus_handle(cx),
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx),
|
||||
ActiveView::History => self.history.focus_handle(cx),
|
||||
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
|
||||
ActiveView::Configuration => {
|
||||
|
@ -1674,9 +1719,11 @@ impl AgentPanel {
|
|||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
ActiveView::AcpThread { thread_view } => Label::new(thread_view.read(cx).title(cx))
|
||||
.truncate()
|
||||
.into_any_element(),
|
||||
ActiveView::ExternalAgentThread { thread_view } => {
|
||||
Label::new(thread_view.read(cx).title(cx))
|
||||
.truncate()
|
||||
.into_any_element()
|
||||
}
|
||||
ActiveView::TextThread {
|
||||
title_editor,
|
||||
context_editor,
|
||||
|
@ -1811,7 +1858,7 @@ impl AgentPanel {
|
|||
|
||||
let active_thread = match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
|
||||
ActiveView::AcpThread { .. }
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => None,
|
||||
|
@ -1849,7 +1896,20 @@ impl AgentPanel {
|
|||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.separator()
|
||||
.header("External Agents")
|
||||
.action("New Gemini Thread", NewAcpThread.boxed_clone())
|
||||
.action(
|
||||
"New Gemini Thread",
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
)
|
||||
.action(
|
||||
"New Claude Code Thread",
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::ClaudeCode),
|
||||
}
|
||||
.boxed_clone(),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
|
@ -2090,7 +2150,11 @@ impl AgentPanel {
|
|||
|
||||
Some(element.into_any_element())
|
||||
}
|
||||
_ => None,
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2119,7 +2183,7 @@ impl AgentPanel {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
ActiveView::ExternalAgentThread { .. } => {
|
||||
return false;
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
|
@ -2706,7 +2770,7 @@ impl AgentPanel {
|
|||
) -> Option<AnyElement> {
|
||||
let active_thread = match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => thread,
|
||||
ActiveView::AcpThread { .. } => {
|
||||
ActiveView::ExternalAgentThread { .. } => {
|
||||
return None;
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {
|
||||
|
@ -3055,7 +3119,7 @@ impl AgentPanel {
|
|||
.detach();
|
||||
});
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {
|
||||
ActiveView::ExternalAgentThread { .. } => {
|
||||
unimplemented!()
|
||||
}
|
||||
ActiveView::TextThread { context_editor, .. } => {
|
||||
|
@ -3077,7 +3141,7 @@ impl AgentPanel {
|
|||
let mut key_context = KeyContext::new_with_defaults();
|
||||
key_context.add("AgentPanel");
|
||||
match &self.active_view {
|
||||
ActiveView::AcpThread { .. } => key_context.add("acp_thread"),
|
||||
ActiveView::ExternalAgentThread { .. } => key_context.add("external_agent_thread"),
|
||||
ActiveView::TextThread { .. } => key_context.add("prompt_editor"),
|
||||
ActiveView::Thread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
}
|
||||
|
@ -3133,7 +3197,7 @@ impl Render for AgentPanel {
|
|||
});
|
||||
this.continue_conversation(window, cx);
|
||||
}
|
||||
ActiveView::AcpThread { .. } => {}
|
||||
ActiveView::ExternalAgentThread { .. } => {}
|
||||
ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
|
@ -3175,7 +3239,7 @@ impl Render for AgentPanel {
|
|||
})
|
||||
.child(h_flex().child(message_editor.clone()))
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::AcpThread { thread_view, .. } => parent
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => parent
|
||||
.relative()
|
||||
.child(thread_view.clone())
|
||||
.child(self.render_drag_target(cx)),
|
||||
|
|
|
@ -25,6 +25,7 @@ mod thread_history;
|
|||
mod tool_compatibility;
|
||||
mod ui;
|
||||
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent::{Thread, ThreadId};
|
||||
|
@ -40,7 +41,7 @@ use language_model::{
|
|||
};
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
|
@ -57,8 +58,6 @@ actions!(
|
|||
[
|
||||
/// Creates a new text-based conversation thread.
|
||||
NewTextThread,
|
||||
/// Creates a new external agent conversation thread.
|
||||
NewAcpThread,
|
||||
/// Toggles the context picker interface for adding files, symbols, or other context.
|
||||
ToggleContextPicker,
|
||||
/// Toggles the navigation menu for switching between threads and views.
|
||||
|
@ -133,6 +132,32 @@ pub struct NewThread {
|
|||
from_thread_id: Option<ThreadId>,
|
||||
}
|
||||
|
||||
/// Creates a new external agent conversation thread.
|
||||
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct NewExternalAgentThread {
|
||||
/// Which agent to use for the conversation.
|
||||
agent: Option<ExternalAgent>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ExternalAgent {
|
||||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
pub fn server(&self) -> Rc<dyn agent_servers::AgentServer> {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Opens the profile management interface for configuring agent tools and settings.
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
|
|
|
@ -21,12 +21,14 @@ collections.workspace = true
|
|||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
log.workspace = true
|
||||
net.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
|
|
@ -70,12 +70,12 @@ fn is_null_value<T: Serialize>(value: &T) -> bool {
|
|||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Request<'a, T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
method: &'a str,
|
||||
pub struct Request<'a, T> {
|
||||
pub jsonrpc: &'static str,
|
||||
pub id: RequestId,
|
||||
pub method: &'a str,
|
||||
#[serde(skip_serializing_if = "is_null_value")]
|
||||
params: T,
|
||||
pub params: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
@ -88,18 +88,18 @@ struct AnyResponse<'a> {
|
|||
result: Option<&'a RawValue>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct Response<T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
pub(crate) struct Response<T> {
|
||||
pub jsonrpc: &'static str,
|
||||
pub id: RequestId,
|
||||
#[serde(flatten)]
|
||||
value: CspResult<T>,
|
||||
pub value: CspResult<T>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum CspResult<T> {
|
||||
pub(crate) enum CspResult<T> {
|
||||
#[serde(rename = "result")]
|
||||
Ok(Option<T>),
|
||||
#[allow(dead_code)]
|
||||
|
@ -123,8 +123,9 @@ struct AnyNotification<'a> {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Error {
|
||||
message: String,
|
||||
pub(crate) struct Error {
|
||||
pub message: String,
|
||||
pub code: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
pub mod client;
|
||||
pub mod listener;
|
||||
pub mod protocol;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
|
236
crates/context_server/src/listener.rs
Normal file
236
crates/context_server/src/listener.rs
Normal file
|
@ -0,0 +1,236 @@
|
|||
use ::serde::{Deserialize, Serialize};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt,
|
||||
channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded},
|
||||
io::BufReader,
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Task};
|
||||
use net::async_net::{UnixListener, UnixStream};
|
||||
use serde_json::{json, value::RawValue};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
client::{CspResult, RequestId, Response},
|
||||
types::Request,
|
||||
};
|
||||
|
||||
pub struct McpServer {
|
||||
socket_path: PathBuf,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
_server_task: Task<()>,
|
||||
}
|
||||
|
||||
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
|
||||
|
||||
impl McpServer {
|
||||
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
|
||||
let task = cx.background_spawn(async move {
|
||||
let temp_dir = tempfile::Builder::new().prefix("zed-mcp").tempdir()?;
|
||||
let socket_path = temp_dir.path().join("mcp.sock");
|
||||
let listener = UnixListener::bind(&socket_path).context("creating mcp socket")?;
|
||||
|
||||
anyhow::Ok((temp_dir, socket_path, listener))
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (temp_dir, socket_path, listener) = task.await?;
|
||||
let handlers = Rc::new(RefCell::new(HashMap::default()));
|
||||
let server_task = cx.spawn({
|
||||
let handlers = handlers.clone();
|
||||
async move |cx| {
|
||||
while let Ok((stream, _)) = listener.accept().await {
|
||||
Self::serve_connection(stream, handlers.clone(), cx);
|
||||
}
|
||||
drop(temp_dir)
|
||||
}
|
||||
});
|
||||
Ok(Self {
|
||||
socket_path,
|
||||
_server_task: server_task,
|
||||
handlers: handlers.clone(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_request<R: Request>(
|
||||
&mut self,
|
||||
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
|
||||
) {
|
||||
let f = Box::new(f);
|
||||
self.handlers.borrow_mut().insert(
|
||||
R::METHOD,
|
||||
Box::new(move |req_id, opt_params, cx| {
|
||||
let result = match opt_params {
|
||||
Some(params) => serde_json::from_str(params.get()),
|
||||
None => serde_json::from_value(serde_json::Value::Null),
|
||||
};
|
||||
|
||||
let params: R::Params = match result {
|
||||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
return Task::ready(
|
||||
serde_json::to_string(&Response::<R::Response> {
|
||||
jsonrpc: "2.0",
|
||||
id: req_id,
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: format!("{e}"),
|
||||
code: -32700,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
};
|
||||
let task = f(params, cx);
|
||||
cx.background_spawn(async move {
|
||||
match task.await {
|
||||
Ok(result) => serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: req_id,
|
||||
value: CspResult::Ok(Some(result)),
|
||||
})
|
||||
.unwrap(),
|
||||
Err(e) => serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: req_id,
|
||||
value: CspResult::Error::<R::Response>(Some(crate::client::Error {
|
||||
message: format!("{e}"),
|
||||
code: -32603,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
}
|
||||
})
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn socket_path(&self) -> &Path {
|
||||
&self.socket_path
|
||||
}
|
||||
|
||||
fn serve_connection(
|
||||
stream: UnixStream,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let (read, write) = smol::io::split(stream);
|
||||
let (incoming_tx, mut incoming_rx) = unbounded();
|
||||
let (outgoing_tx, outgoing_rx) = unbounded();
|
||||
|
||||
cx.background_spawn(Self::handle_io(outgoing_rx, incoming_tx, write, read))
|
||||
.detach();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
while let Some(request) = incoming_rx.next().await {
|
||||
let Some(request_id) = request.id.clone() else {
|
||||
continue;
|
||||
};
|
||||
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
|
||||
let outgoing_tx = outgoing_tx.clone();
|
||||
|
||||
if let Some(task) = cx
|
||||
.update(|cx| handler(request_id, request.params, cx))
|
||||
.log_err()
|
||||
{
|
||||
cx.spawn(async move |_| {
|
||||
let response = task.await;
|
||||
outgoing_tx.unbounded_send(response).ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
} else {
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response::<()> {
|
||||
jsonrpc: "2.0",
|
||||
id: request.id.unwrap(),
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: format!("unhandled method {}", request.method),
|
||||
code: -32601,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
async fn handle_io(
|
||||
mut outgoing_rx: UnboundedReceiver<String>,
|
||||
incoming_tx: UnboundedSender<RawRequest>,
|
||||
mut outgoing_bytes: impl Unpin + AsyncWrite,
|
||||
incoming_bytes: impl Unpin + AsyncRead,
|
||||
) -> Result<()> {
|
||||
let mut output_reader = BufReader::new(incoming_bytes);
|
||||
let mut incoming_line = String::new();
|
||||
loop {
|
||||
select_biased! {
|
||||
message = outgoing_rx.next().fuse() => {
|
||||
if let Some(message) = message {
|
||||
log::trace!("send: {}", &message);
|
||||
outgoing_bytes.write_all(message.as_bytes()).await?;
|
||||
outgoing_bytes.write_all(&[b'\n']).await?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
|
||||
if bytes_read? == 0 {
|
||||
break
|
||||
}
|
||||
log::trace!("recv: {}", &incoming_line);
|
||||
match serde_json::from_str(&incoming_line) {
|
||||
Ok(message) => {
|
||||
incoming_tx.unbounded_send(message).log_err();
|
||||
}
|
||||
Err(error) => {
|
||||
outgoing_bytes.write_all(serde_json::to_string(&json!({
|
||||
"jsonrpc": "2.0",
|
||||
"error": json!({
|
||||
"code": -32603,
|
||||
"message": format!("Failed to parse: {error}"),
|
||||
}),
|
||||
}))?.as_bytes()).await?;
|
||||
outgoing_bytes.write_all(&[b'\n']).await?;
|
||||
log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
|
||||
}
|
||||
}
|
||||
incoming_line.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RawRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<RequestId>,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<Box<serde_json::value::RawValue>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RawResponse {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<crate::client::Error>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<Box<serde_json::value::RawValue>>,
|
||||
}
|
|
@ -153,7 +153,7 @@ pub struct InitializeParams {
|
|||
pub struct CallToolParams {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<HashMap<String, serde_json::Value>>,
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ pub enum IconName {
|
|||
Ai,
|
||||
AiAnthropic,
|
||||
AiBedrock,
|
||||
AiClaude,
|
||||
AiDeepSeek,
|
||||
AiEdit,
|
||||
AiGemini,
|
||||
|
|
20
crates/nc/Cargo.toml
Normal file
20
crates/nc/Cargo.toml
Normal file
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "nc"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/nc.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
net.workspace = true
|
||||
smol.workspace = true
|
||||
workspace-hack.workspace = true
|
1
crates/nc/LICENSE-GPL
Symbolic link
1
crates/nc/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
56
crates/nc/src/nc.rs
Normal file
56
crates/nc/src/nc.rs
Normal file
|
@ -0,0 +1,56 @@
|
|||
use anyhow::Result;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub fn main(_socket: &str) -> Result<()> {
|
||||
// It looks like we can't get an async stdio stream on Windows from smol.
|
||||
//
|
||||
// We decided to merge this with a panic on Windows since this is only used
|
||||
// by the experimental Claude Code Agent Server.
|
||||
//
|
||||
// We're tracking this internally, and we will address it before shipping the integration.
|
||||
panic!("--nc isn't yet supported on Windows");
|
||||
}
|
||||
|
||||
/// The main function for when Zed is running in netcat mode
|
||||
#[cfg(not(windows))]
|
||||
pub fn main(socket: &str) -> Result<()> {
|
||||
use futures::{AsyncReadExt as _, AsyncWriteExt as _, FutureExt as _, io::BufReader, select};
|
||||
use net::async_net::UnixStream;
|
||||
use smol::{Async, io::AsyncBufReadExt};
|
||||
|
||||
smol::block_on(async {
|
||||
let socket_stream = UnixStream::connect(socket).await?;
|
||||
let (socket_read, mut socket_write) = socket_stream.split();
|
||||
let mut socket_reader = BufReader::new(socket_read);
|
||||
|
||||
let mut stdout = Async::new(std::io::stdout())?;
|
||||
let stdin = Async::new(std::io::stdin())?;
|
||||
let mut stdin_reader = BufReader::new(stdin);
|
||||
|
||||
let mut socket_line = Vec::new();
|
||||
let mut stdin_line = Vec::new();
|
||||
|
||||
loop {
|
||||
select! {
|
||||
bytes_read = socket_reader.read_until(b'\n', &mut socket_line).fuse() => {
|
||||
if bytes_read? == 0 {
|
||||
break
|
||||
}
|
||||
stdout.write_all(&socket_line).await?;
|
||||
stdout.flush().await?;
|
||||
socket_line.clear();
|
||||
}
|
||||
bytes_read = stdin_reader.read_until(b'\n', &mut stdin_line).fuse() => {
|
||||
if bytes_read? == 0 {
|
||||
break
|
||||
}
|
||||
socket_write.write_all(&stdin_line).await?;
|
||||
socket_write.flush().await?;
|
||||
stdin_line.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
|
@ -95,6 +95,7 @@ svg_preview.workspace = true
|
|||
menu.workspace = true
|
||||
migrator.workspace = true
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
nc.workspace = true
|
||||
nix = { workspace = true, features = ["pthread", "signal"] }
|
||||
node_runtime.workspace = true
|
||||
notifications.workspace = true
|
||||
|
|
|
@ -175,6 +175,17 @@ pub fn main() {
|
|||
return;
|
||||
}
|
||||
|
||||
// `zed --nc` Makes zed operate in nc/netcat mode for use with MCP
|
||||
if let Some(socket) = &args.nc {
|
||||
match nc::main(socket) {
|
||||
Ok(()) => return,
|
||||
Err(err) => {
|
||||
eprintln!("Error: {}", err);
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// `zed --printenv` Outputs environment variables as JSON to stdout
|
||||
if args.printenv {
|
||||
util::shell_env::print_env();
|
||||
|
@ -1168,6 +1179,11 @@ struct Args {
|
|||
#[arg(long, hide = true)]
|
||||
askpass: Option<String>,
|
||||
|
||||
/// Used for the MCP Server, to remove the need for netcat as a dependency,
|
||||
/// by having Zed act like netcat communicating over a Unix socket.
|
||||
#[arg(long, hide = true)]
|
||||
nc: Option<String>,
|
||||
|
||||
/// Run zed in the foreground, only used on Windows, to match the behavior on macOS.
|
||||
#[arg(long)]
|
||||
#[cfg(target_os = "windows")]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue