Extract acp_connection
This commit is contained in:
parent
0f395df9a8
commit
ced3d09f10
11 changed files with 305 additions and 415 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -138,7 +138,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "agent-client-protocol"
|
name = "agent-client-protocol"
|
||||||
version = "0.0.12"
|
version = "0.0.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4255a06cc2414033d1fe4baf1968bcc8f16d7e5814f272b97779b5806d129142"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -413,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||||
#
|
#
|
||||||
|
|
||||||
agentic-coding-protocol = "0.0.10"
|
agentic-coding-protocol = "0.0.10"
|
||||||
agent-client-protocol = {path="../agent-client-protocol"}
|
agent-client-protocol = "0.0.13"
|
||||||
aho-corasick = "1.1"
|
aho-corasick = "1.1"
|
||||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||||
any_vec = "0.14"
|
any_vec = "0.14"
|
||||||
|
|
|
@ -616,6 +616,7 @@ impl Error for LoadError {}
|
||||||
|
|
||||||
impl AcpThread {
|
impl AcpThread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
title: impl Into<SharedString>,
|
||||||
connection: Rc<dyn AgentConnection>,
|
connection: Rc<dyn AgentConnection>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
session_id: acp::SessionId,
|
session_id: acp::SessionId,
|
||||||
|
@ -628,7 +629,7 @@ impl AcpThread {
|
||||||
shared_buffers: Default::default(),
|
shared_buffers: Default::default(),
|
||||||
entries: Default::default(),
|
entries: Default::default(),
|
||||||
plan: Default::default(),
|
plan: Default::default(),
|
||||||
title: connection.name().into(),
|
title: title.into(),
|
||||||
project,
|
project,
|
||||||
send_task: None,
|
send_task: None,
|
||||||
connection,
|
connection,
|
||||||
|
|
|
@ -9,8 +9,6 @@ use ui::App;
|
||||||
use crate::AcpThread;
|
use crate::AcpThread;
|
||||||
|
|
||||||
pub trait AgentConnection {
|
pub trait AgentConnection {
|
||||||
fn name(&self) -> &'static str;
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
|
|
@ -367,10 +367,6 @@ pub struct OldAcpAgentConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for OldAcpAgentConnection {
|
impl AgentConnection for OldAcpAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -394,7 +390,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let thread = cx.new(|cx| {
|
let thread = cx.new(|cx| {
|
||||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||||
AcpThread::new(self.clone(), project, session_id, cx)
|
AcpThread::new("Gemini", self.clone(), project, session_id, cx)
|
||||||
});
|
});
|
||||||
thread
|
thread
|
||||||
})
|
})
|
||||||
|
|
0
crates/agent_servers/acp
Normal file
0
crates/agent_servers/acp
Normal file
256
crates/agent_servers/src/acp_connection.rs
Normal file
256
crates/agent_servers/src/acp_connection.rs
Normal file
|
@ -0,0 +1,256 @@
|
||||||
|
use agent_client_protocol as acp;
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use collections::HashMap;
|
||||||
|
use context_server::listener::McpServerTool;
|
||||||
|
use context_server::types::requests;
|
||||||
|
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||||
|
use futures::channel::{mpsc, oneshot};
|
||||||
|
use project::Project;
|
||||||
|
use smol::stream::StreamExt as _;
|
||||||
|
use std::cell::RefCell;
|
||||||
|
use std::rc::Rc;
|
||||||
|
use std::{path::Path, sync::Arc};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||||
|
|
||||||
|
use crate::mcp_server::ZedMcpServer;
|
||||||
|
use crate::{AgentServerCommand, mcp_server};
|
||||||
|
use acp_thread::{AcpThread, AgentConnection};
|
||||||
|
|
||||||
|
pub struct AcpConnection {
|
||||||
|
server_name: &'static str,
|
||||||
|
client: Arc<context_server::ContextServer>,
|
||||||
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
|
_notification_handler_task: Task<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AcpConnection {
|
||||||
|
pub async fn stdio(
|
||||||
|
server_name: &'static str,
|
||||||
|
command: AgentServerCommand,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let client: Arc<ContextServer> = ContextServer::stdio(
|
||||||
|
ContextServerId(format!("{}-mcp-server", server_name).into()),
|
||||||
|
ContextServerCommand {
|
||||||
|
path: command.path,
|
||||||
|
args: command.args,
|
||||||
|
env: command.env,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
ContextServer::start(client.clone(), cx).await?;
|
||||||
|
|
||||||
|
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||||
|
client
|
||||||
|
.client()
|
||||||
|
.context("Failed to subscribe")?
|
||||||
|
.on_notification(acp::AGENT_METHODS.session_update, {
|
||||||
|
move |notification, _cx| {
|
||||||
|
let notification_tx = notification_tx.clone();
|
||||||
|
log::trace!(
|
||||||
|
"ACP Notification: {}",
|
||||||
|
serde_json::to_string_pretty(¬ification).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(notification) =
|
||||||
|
serde_json::from_value::<acp::SessionNotification>(notification).log_err()
|
||||||
|
{
|
||||||
|
notification_tx.unbounded_send(notification).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||||
|
|
||||||
|
let notification_handler_task = cx.spawn({
|
||||||
|
let sessions = sessions.clone();
|
||||||
|
async move |cx| {
|
||||||
|
while let Some(notification) = notification_rx.next().await {
|
||||||
|
Self::handle_session_notification(notification, sessions.clone(), cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
server_name,
|
||||||
|
client,
|
||||||
|
sessions,
|
||||||
|
_notification_handler_task: notification_handler_task,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_session_notification(
|
||||||
|
notification: acp::SessionNotification,
|
||||||
|
threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) {
|
||||||
|
let threads = threads.borrow();
|
||||||
|
let Some(thread) = threads
|
||||||
|
.get(¬ification.session_id)
|
||||||
|
.and_then(|session| session.thread.upgrade())
|
||||||
|
else {
|
||||||
|
log::error!(
|
||||||
|
"Thread not found for session ID: {}",
|
||||||
|
notification.session_id
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
thread.handle_session_update(notification.update, cx)
|
||||||
|
})
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AcpSession {
|
||||||
|
thread: WeakEntity<AcpThread>,
|
||||||
|
cancel_tx: Option<oneshot::Sender<()>>,
|
||||||
|
_mcp_server: ZedMcpServer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentConnection for AcpConnection {
|
||||||
|
fn new_thread(
|
||||||
|
self: Rc<Self>,
|
||||||
|
project: Entity<Project>,
|
||||||
|
cwd: &Path,
|
||||||
|
cx: &mut AsyncApp,
|
||||||
|
) -> Task<Result<Entity<AcpThread>>> {
|
||||||
|
let client = self.client.client();
|
||||||
|
let sessions = self.sessions.clone();
|
||||||
|
let cwd = cwd.to_path_buf();
|
||||||
|
cx.spawn(async move |cx| {
|
||||||
|
let client = client.context("MCP server is not initialized yet")?;
|
||||||
|
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||||
|
|
||||||
|
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||||
|
name: acp::AGENT_METHODS.new_session.into(),
|
||||||
|
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
||||||
|
mcp_servers: vec![mcp_server.server_config()?],
|
||||||
|
client_tools: acp::ClientTools {
|
||||||
|
request_permission: Some(acp::McpToolId {
|
||||||
|
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||||
|
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
||||||
|
}),
|
||||||
|
read_text_file: Some(acp::McpToolId {
|
||||||
|
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||||
|
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
||||||
|
}),
|
||||||
|
write_text_file: Some(acp::McpToolId {
|
||||||
|
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||||
|
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
cwd,
|
||||||
|
})?),
|
||||||
|
meta: None,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if response.is_error.unwrap_or_default() {
|
||||||
|
return Err(anyhow!(response.text_contents()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
||||||
|
response.structured_content.context("Empty response")?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let thread = cx.new(|cx| {
|
||||||
|
AcpThread::new(
|
||||||
|
self.server_name,
|
||||||
|
self.clone(),
|
||||||
|
project,
|
||||||
|
result.session_id.clone(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
thread_tx.send(thread.downgrade())?;
|
||||||
|
|
||||||
|
let session = AcpSession {
|
||||||
|
thread: thread.downgrade(),
|
||||||
|
cancel_tx: None,
|
||||||
|
_mcp_server: mcp_server,
|
||||||
|
};
|
||||||
|
sessions.borrow_mut().insert(result.session_id, session);
|
||||||
|
|
||||||
|
Ok(thread)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||||
|
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt(
|
||||||
|
&self,
|
||||||
|
params: agent_client_protocol::PromptArguments,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<()>> {
|
||||||
|
let client = self.client.client();
|
||||||
|
let sessions = self.sessions.clone();
|
||||||
|
|
||||||
|
cx.foreground_executor().spawn(async move {
|
||||||
|
let client = client.context("MCP server is not initialized yet")?;
|
||||||
|
|
||||||
|
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||||
|
{
|
||||||
|
let mut sessions = sessions.borrow_mut();
|
||||||
|
let session = sessions
|
||||||
|
.get_mut(¶ms.session_id)
|
||||||
|
.context("Session not found")?;
|
||||||
|
session.cancel_tx.replace(new_cancel_tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = client
|
||||||
|
.request_with::<requests::CallTool>(
|
||||||
|
context_server::types::CallToolParams {
|
||||||
|
name: acp::AGENT_METHODS.prompt.into(),
|
||||||
|
arguments: Some(serde_json::to_value(params)?),
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
|
Some(cancel_rx),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(err) = &result
|
||||||
|
&& err.is::<context_server::client::RequestCanceled>()
|
||||||
|
{
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = result?;
|
||||||
|
|
||||||
|
if response.is_error.unwrap_or_default() {
|
||||||
|
return Err(anyhow!(response.text_contents()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
||||||
|
let mut sessions = self.sessions.borrow_mut();
|
||||||
|
|
||||||
|
if let Some(cancel_tx) = sessions
|
||||||
|
.get_mut(session_id)
|
||||||
|
.and_then(|session| session.cancel_tx.take())
|
||||||
|
{
|
||||||
|
cancel_tx.send(()).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for AcpConnection {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.client.stop().log_err();
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
mod acp_connection;
|
||||||
mod claude;
|
mod claude;
|
||||||
mod codex;
|
mod codex;
|
||||||
mod gemini;
|
mod gemini;
|
||||||
|
|
|
@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentConnection for ClaudeAgentConnection {
|
impl AgentConnection for ClaudeAgentConnection {
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
ClaudeCode.name()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
self: Rc<Self>,
|
self: Rc<Self>,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let thread =
|
let thread = cx.new(|cx| {
|
||||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||||
|
})?;
|
||||||
|
|
||||||
thread_tx.send(thread.downgrade())?;
|
thread_tx.send(thread.downgrade())?;
|
||||||
|
|
||||||
|
|
|
@ -1,24 +1,14 @@
|
||||||
use agent_client_protocol as acp;
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use collections::HashMap;
|
|
||||||
use context_server::listener::McpServerTool;
|
|
||||||
use context_server::types::requests;
|
|
||||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
|
||||||
use futures::channel::{mpsc, oneshot};
|
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt as _;
|
use std::path::Path;
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::{path::Path, sync::Arc};
|
|
||||||
use util::ResultExt;
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::Result;
|
||||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
use gpui::{App, Entity, Task};
|
||||||
|
|
||||||
use crate::mcp_server::ZedMcpServer;
|
use crate::acp_connection::AcpConnection;
|
||||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||||
use acp_thread::{AcpThread, AgentConnection};
|
use acp_thread::AgentConnection;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Codex;
|
pub struct Codex;
|
||||||
|
@ -47,6 +37,7 @@ impl AgentServer for Codex {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||||
let project = project.clone();
|
let project = project.clone();
|
||||||
|
let server_name = self.name();
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||||
|
@ -58,240 +49,12 @@ impl AgentServer for Codex {
|
||||||
anyhow::bail!("Failed to find codex binary");
|
anyhow::bail!("Failed to find codex binary");
|
||||||
};
|
};
|
||||||
|
|
||||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
let conn = AcpConnection::stdio(server_name, command, cx).await?;
|
||||||
ContextServerId("codex-mcp-server".into()),
|
Ok(Rc::new(conn) as _)
|
||||||
ContextServerCommand {
|
|
||||||
path: command.path,
|
|
||||||
args: command.args,
|
|
||||||
env: command.env,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.into();
|
|
||||||
ContextServer::start(client.clone(), cx).await?;
|
|
||||||
|
|
||||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
|
||||||
client
|
|
||||||
.client()
|
|
||||||
.context("Failed to subscribe")?
|
|
||||||
.on_notification(acp::AGENT_METHODS.session_update, {
|
|
||||||
move |notification, _cx| {
|
|
||||||
let notification_tx = notification_tx.clone();
|
|
||||||
log::trace!(
|
|
||||||
"ACP Notification: {}",
|
|
||||||
serde_json::to_string_pretty(¬ification).unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(notification) =
|
|
||||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
|
||||||
.log_err()
|
|
||||||
{
|
|
||||||
notification_tx.unbounded_send(notification).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
|
||||||
|
|
||||||
let notification_handler_task = cx.spawn({
|
|
||||||
let sessions = sessions.clone();
|
|
||||||
async move |cx| {
|
|
||||||
while let Some(notification) = notification_rx.next().await {
|
|
||||||
CodexConnection::handle_session_notification(
|
|
||||||
notification,
|
|
||||||
sessions.clone(),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let connection = CodexConnection {
|
|
||||||
client,
|
|
||||||
sessions,
|
|
||||||
_notification_handler_task: notification_handler_task,
|
|
||||||
};
|
|
||||||
Ok(Rc::new(connection) as _)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CodexConnection {
|
|
||||||
client: Arc<context_server::ContextServer>,
|
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
|
||||||
_notification_handler_task: Task<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CodexSession {
|
|
||||||
thread: WeakEntity<AcpThread>,
|
|
||||||
cancel_tx: Option<oneshot::Sender<()>>,
|
|
||||||
_mcp_server: ZedMcpServer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AgentConnection for CodexConnection {
|
|
||||||
fn name(&self) -> &'static str {
|
|
||||||
"Codex"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_thread(
|
|
||||||
self: Rc<Self>,
|
|
||||||
project: Entity<Project>,
|
|
||||||
cwd: &Path,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Task<Result<Entity<AcpThread>>> {
|
|
||||||
let client = self.client.client();
|
|
||||||
let sessions = self.sessions.clone();
|
|
||||||
let cwd = cwd.to_path_buf();
|
|
||||||
cx.spawn(async move |cx| {
|
|
||||||
let client = client.context("MCP server is not initialized yet")?;
|
|
||||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
|
||||||
|
|
||||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
|
||||||
|
|
||||||
let response = client
|
|
||||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
|
||||||
name: acp::AGENT_METHODS.new_session.into(),
|
|
||||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
|
||||||
mcp_servers: vec![mcp_server.server_config()?],
|
|
||||||
client_tools: acp::ClientTools {
|
|
||||||
request_permission: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
read_text_file: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
write_text_file: Some(acp::McpToolId {
|
|
||||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
|
||||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
cwd,
|
|
||||||
})?),
|
|
||||||
meta: None,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if response.is_error.unwrap_or_default() {
|
|
||||||
return Err(anyhow!(response.text_contents()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
|
||||||
response.structured_content.context("Empty response")?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let thread =
|
|
||||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
|
||||||
|
|
||||||
thread_tx.send(thread.downgrade())?;
|
|
||||||
|
|
||||||
let session = CodexSession {
|
|
||||||
thread: thread.downgrade(),
|
|
||||||
cancel_tx: None,
|
|
||||||
_mcp_server: mcp_server,
|
|
||||||
};
|
|
||||||
sessions.borrow_mut().insert(result.session_id, session);
|
|
||||||
|
|
||||||
Ok(thread)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
|
||||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prompt(
|
|
||||||
&self,
|
|
||||||
params: agent_client_protocol::PromptArguments,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<Result<()>> {
|
|
||||||
let client = self.client.client();
|
|
||||||
let sessions = self.sessions.clone();
|
|
||||||
|
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let client = client.context("MCP server is not initialized yet")?;
|
|
||||||
|
|
||||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
|
||||||
{
|
|
||||||
let mut sessions = sessions.borrow_mut();
|
|
||||||
let session = sessions
|
|
||||||
.get_mut(¶ms.session_id)
|
|
||||||
.context("Session not found")?;
|
|
||||||
session.cancel_tx.replace(new_cancel_tx);
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = client
|
|
||||||
.request_with::<requests::CallTool>(
|
|
||||||
context_server::types::CallToolParams {
|
|
||||||
name: acp::AGENT_METHODS.prompt.into(),
|
|
||||||
arguments: Some(serde_json::to_value(params)?),
|
|
||||||
meta: None,
|
|
||||||
},
|
|
||||||
Some(cancel_rx),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Err(err) = &result
|
|
||||||
&& err.is::<context_server::client::RequestCanceled>()
|
|
||||||
{
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = result?;
|
|
||||||
|
|
||||||
if response.is_error.unwrap_or_default() {
|
|
||||||
return Err(anyhow!(response.text_contents()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
|
||||||
let mut sessions = self.sessions.borrow_mut();
|
|
||||||
|
|
||||||
if let Some(cancel_tx) = sessions
|
|
||||||
.get_mut(session_id)
|
|
||||||
.and_then(|session| session.cancel_tx.take())
|
|
||||||
{
|
|
||||||
cancel_tx.send(()).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CodexConnection {
|
|
||||||
pub fn handle_session_notification(
|
|
||||||
notification: acp::SessionNotification,
|
|
||||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) {
|
|
||||||
let threads = threads.borrow();
|
|
||||||
let Some(thread) = threads
|
|
||||||
.get(¬ification.session_id)
|
|
||||||
.and_then(|session| session.thread.upgrade())
|
|
||||||
else {
|
|
||||||
log::error!(
|
|
||||||
"Thread not found for session ID: {}",
|
|
||||||
notification.session_id
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
thread
|
|
||||||
.update(cx, |thread, cx| {
|
|
||||||
thread.handle_session_update(notification.update, cx)
|
|
||||||
})
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for CodexConnection {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.client.stop().log_err();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
@ -1,25 +1,18 @@
|
||||||
use anyhow::anyhow;
|
|
||||||
use std::cell::RefCell;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::rc::Rc;
|
|
||||||
use util::ResultExt as _;
|
|
||||||
|
|
||||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
|
||||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
|
||||||
use agentic_coding_protocol as acp_old;
|
|
||||||
use anyhow::{Context as _, Result};
|
|
||||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use ui::App;
|
use std::path::Path;
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
use crate::AllAgentServersSettings;
|
use anyhow::Result;
|
||||||
|
use gpui::{App, Entity, Task};
|
||||||
|
|
||||||
|
use crate::acp_connection::AcpConnection;
|
||||||
|
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
|
||||||
|
use acp_thread::AgentConnection;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Gemini;
|
pub struct Gemini;
|
||||||
|
|
||||||
const ACP_ARG: &str = "--experimental-acp";
|
|
||||||
|
|
||||||
impl AgentServer for Gemini {
|
impl AgentServer for Gemini {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"Gemini"
|
"Gemini"
|
||||||
|
@ -39,147 +32,33 @@ impl AgentServer for Gemini {
|
||||||
|
|
||||||
fn connect(
|
fn connect(
|
||||||
&self,
|
&self,
|
||||||
root_dir: &Path,
|
_root_dir: &Path,
|
||||||
project: &Entity<Project>,
|
project: &Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||||
let root_dir = root_dir.to_path_buf();
|
|
||||||
let project = project.clone();
|
let project = project.clone();
|
||||||
let this = self.clone();
|
let server_name = self.name();
|
||||||
let name = self.name();
|
|
||||||
|
|
||||||
cx.spawn(async move |cx| {
|
cx.spawn(async move |cx| {
|
||||||
let command = this.command(&project, cx).await?;
|
|
||||||
|
|
||||||
let mut child = util::command::new_smol_command(&command.path)
|
|
||||||
.args(command.args.iter())
|
|
||||||
.current_dir(root_dir)
|
|
||||||
.stdin(std::process::Stdio::piped())
|
|
||||||
.stdout(std::process::Stdio::piped())
|
|
||||||
.stderr(std::process::Stdio::inherit())
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.spawn()?;
|
|
||||||
|
|
||||||
let stdin = child.stdin.take().unwrap();
|
|
||||||
let stdout = child.stdout.take().unwrap();
|
|
||||||
|
|
||||||
let foreground_executor = cx.foreground_executor().clone();
|
|
||||||
|
|
||||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
|
||||||
|
|
||||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
|
||||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
|
||||||
stdin,
|
|
||||||
stdout,
|
|
||||||
move |fut| foreground_executor.spawn(fut).detach(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let io_task = cx.background_spawn(async move {
|
|
||||||
io_fut.await.log_err();
|
|
||||||
});
|
|
||||||
|
|
||||||
let child_status = cx.background_spawn(async move {
|
|
||||||
let result = match child.status().await {
|
|
||||||
Err(e) => Err(anyhow!(e)),
|
|
||||||
Ok(result) if result.success() => Ok(()),
|
|
||||||
Ok(result) => {
|
|
||||||
if let Some(AgentServerVersion::Unsupported {
|
|
||||||
error_message,
|
|
||||||
upgrade_message,
|
|
||||||
upgrade_command,
|
|
||||||
}) = this.version(&command).await.log_err()
|
|
||||||
{
|
|
||||||
Err(anyhow!(LoadError::Unsupported {
|
|
||||||
error_message,
|
|
||||||
upgrade_message,
|
|
||||||
upgrade_command
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
drop(io_task);
|
|
||||||
result
|
|
||||||
});
|
|
||||||
|
|
||||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
|
||||||
name,
|
|
||||||
connection,
|
|
||||||
child_status,
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(connection)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Gemini {
|
|
||||||
async fn command(
|
|
||||||
&self,
|
|
||||||
project: &Entity<Project>,
|
|
||||||
cx: &mut AsyncApp,
|
|
||||||
) -> Result<AgentServerCommand> {
|
|
||||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if let Some(command) =
|
let Some(command) = AgentServerCommand::resolve(
|
||||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
"gemini",
|
||||||
{
|
&["--experimental-mcp"],
|
||||||
return Ok(command);
|
settings,
|
||||||
|
&project,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
else {
|
||||||
|
anyhow::bail!("Failed to find gemini binary");
|
||||||
};
|
};
|
||||||
|
|
||||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
let conn = AcpConnection::stdio(server_name, command, cx).await?;
|
||||||
(project.fs().clone(), project.node_runtime().cloned())
|
Ok(Rc::new(conn) as _)
|
||||||
})?;
|
|
||||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
|
||||||
|
|
||||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
|
||||||
fs.create_dir(&directory).await?;
|
|
||||||
node_runtime
|
|
||||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
|
||||||
.await?;
|
|
||||||
let path = directory.join("node_modules/.bin/gemini");
|
|
||||||
|
|
||||||
Ok(AgentServerCommand {
|
|
||||||
path,
|
|
||||||
args: vec![ACP_ARG.into()],
|
|
||||||
env: None,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
|
||||||
let version_fut = util::command::new_smol_command(&command.path)
|
|
||||||
.args(command.args.iter())
|
|
||||||
.arg("--version")
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.output();
|
|
||||||
|
|
||||||
let help_fut = util::command::new_smol_command(&command.path)
|
|
||||||
.args(command.args.iter())
|
|
||||||
.arg("--help")
|
|
||||||
.kill_on_drop(true)
|
|
||||||
.output();
|
|
||||||
|
|
||||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
|
||||||
|
|
||||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
|
||||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
|
||||||
|
|
||||||
if supported {
|
|
||||||
Ok(AgentServerVersion::Supported)
|
|
||||||
} else {
|
|
||||||
Ok(AgentServerVersion::Unsupported {
|
|
||||||
error_message: format!(
|
|
||||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
|
||||||
current_version
|
|
||||||
).into(),
|
|
||||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
|
||||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -188,17 +67,14 @@ pub(crate) mod tests {
|
||||||
use crate::AgentServerCommand;
|
use crate::AgentServerCommand;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
crate::common_e2e_tests!(Gemini, allow_option_id = "0");
|
crate::common_e2e_tests!(Gemini, allow_option_id = "allow");
|
||||||
|
|
||||||
pub fn local_command() -> AgentServerCommand {
|
pub fn local_command() -> AgentServerCommand {
|
||||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini/packages/cli");
|
||||||
.join("../../../gemini-cli/packages/cli")
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
AgentServerCommand {
|
AgentServerCommand {
|
||||||
path: "node".into(),
|
path: "node".into(),
|
||||||
args: vec![cli_path, ACP_ARG.into()],
|
args: vec![cli_path.to_string_lossy().to_string()],
|
||||||
env: None,
|
env: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue