Wire up stop button in claude threads (#34839)
Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
5b3e371812
commit
722a05bc21
3 changed files with 145 additions and 52 deletions
|
@ -4,10 +4,13 @@ mod tools;
|
|||
use collections::HashMap;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::pin::pin;
|
||||
use std::rc::Rc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use agentic_coding_protocol::{
|
||||
self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion,
|
||||
|
@ -16,7 +19,7 @@ use agentic_coding_protocol::{
|
|||
use anyhow::{Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::LocalBoxFuture;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt};
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt};
|
||||
use futures::{
|
||||
AsyncRead, AsyncWrite, FutureExt, StreamExt,
|
||||
channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
|
||||
|
@ -69,13 +72,12 @@ impl AgentServer for ClaudeCode {
|
|||
let (mut delegate_tx, delegate_rx) = watch::channel(None);
|
||||
let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let permission_mcp_server =
|
||||
ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
|
||||
let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
|
||||
|
||||
let mut mcp_servers = HashMap::default();
|
||||
mcp_servers.insert(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
permission_mcp_server.server_config()?,
|
||||
mcp_server.server_config()?,
|
||||
);
|
||||
let mcp_config = McpConfig { mcp_servers };
|
||||
|
||||
|
@ -98,50 +100,58 @@ impl AgentServer for ClaudeCode {
|
|||
anyhow::bail!("Failed to find claude binary");
|
||||
};
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(
|
||||
[
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--print",
|
||||
"--verbose",
|
||||
"--mcp-config",
|
||||
mcp_config_path.to_string_lossy().as_ref(),
|
||||
"--permission-prompt-tool",
|
||||
&format!(
|
||||
"mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::PERMISSION_TOOL
|
||||
),
|
||||
"--allowedTools",
|
||||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
]
|
||||
.into_iter()
|
||||
.chain(command.args.iter().map(|arg| arg.as_str())),
|
||||
)
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
||||
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
||||
let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>();
|
||||
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
log::trace!("Starting session with id: {}", session_id);
|
||||
|
||||
let io_task =
|
||||
ClaudeAgentConnection::handle_io(outgoing_rx, incoming_message_tx, stdin, stdout);
|
||||
cx.background_spawn(async move {
|
||||
io_task.await.log_err();
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
let mut mode = ClaudeSessionMode::Start;
|
||||
|
||||
loop {
|
||||
let mut child =
|
||||
spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir)
|
||||
.await?;
|
||||
mode = ClaudeSessionMode::Resume;
|
||||
|
||||
let pid = child.id();
|
||||
log::trace!("Spawned (pid: {})", pid);
|
||||
|
||||
let mut io_fut = pin!(
|
||||
ClaudeAgentConnection::handle_io(
|
||||
outgoing_rx.take().unwrap(),
|
||||
incoming_message_tx.clone(),
|
||||
child.stdin.take().unwrap(),
|
||||
child.stdout.take().unwrap(),
|
||||
)
|
||||
.fuse()
|
||||
);
|
||||
|
||||
select_biased! {
|
||||
done_tx = cancel_rx.next() => {
|
||||
if let Some(done_tx) = done_tx {
|
||||
log::trace!("Interrupted (pid: {})", pid);
|
||||
let result = send_interrupt(pid as i32);
|
||||
outgoing_rx.replace(io_fut.await?);
|
||||
done_tx.send(result).log_err();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
result = io_fut => {
|
||||
result?;
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("Stopped (pid: {})", pid);
|
||||
break;
|
||||
}
|
||||
|
||||
drop(mcp_config_path);
|
||||
drop(child);
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
|
@ -171,17 +181,32 @@ impl AgentServer for ClaudeCode {
|
|||
delegate,
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
cancel_tx,
|
||||
session_id,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: None,
|
||||
};
|
||||
|
||||
connection._mcp_server = Some(permission_mcp_server);
|
||||
connection._mcp_server = Some(mcp_server);
|
||||
acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
|
||||
let pid = nix::unistd::Pid::from_raw(pid);
|
||||
|
||||
nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
|
||||
.map_err(|e| anyhow!("Failed to interrupt process: {}", e))
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
|
||||
panic!("Cancel not implemented on Windows")
|
||||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
/// Send a request to the agent and wait for a response.
|
||||
fn request_any(
|
||||
|
@ -191,6 +216,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
let delegate = self.delegate.clone();
|
||||
let end_turn_tx = self.end_turn_tx.clone();
|
||||
let outgoing_tx = self.outgoing_tx.clone();
|
||||
let mut cancel_tx = self.cancel_tx.clone();
|
||||
let session_id = self.session_id;
|
||||
async move {
|
||||
match params {
|
||||
// todo: consider sending an empty request so we get the init response?
|
||||
|
@ -229,26 +256,83 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
stop_sequence: None,
|
||||
usage: None,
|
||||
},
|
||||
session_id: None,
|
||||
session_id: Some(session_id),
|
||||
})?;
|
||||
rx.await??;
|
||||
Ok(AnyAgentResult::SendUserMessageResponse(
|
||||
acp::SendUserMessageResponse,
|
||||
))
|
||||
}
|
||||
AnyAgentRequest::CancelSendMessageParams(_) => Ok(
|
||||
AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
|
||||
),
|
||||
AnyAgentRequest::CancelSendMessageParams(_) => {
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
cancel_tx.send(done_tx).await?;
|
||||
done_rx.await??;
|
||||
|
||||
Ok(AnyAgentResult::CancelSendMessageResponse(
|
||||
acp::CancelSendMessageResponse,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ClaudeSessionMode {
|
||||
Start,
|
||||
Resume,
|
||||
}
|
||||
|
||||
async fn spawn_claude(
|
||||
command: &AgentServerCommand,
|
||||
mode: ClaudeSessionMode,
|
||||
session_id: Uuid,
|
||||
mcp_config_path: &Path,
|
||||
root_dir: &Path,
|
||||
) -> Result<Child> {
|
||||
let child = util::command::new_smol_command(&command.path)
|
||||
.args([
|
||||
"--input-format",
|
||||
"stream-json",
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--print",
|
||||
"--verbose",
|
||||
"--mcp-config",
|
||||
mcp_config_path.to_string_lossy().as_ref(),
|
||||
"--permission-prompt-tool",
|
||||
&format!(
|
||||
"mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::PERMISSION_TOOL
|
||||
),
|
||||
"--allowedTools",
|
||||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
])
|
||||
.args(match mode {
|
||||
ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
|
||||
ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
|
||||
})
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
Ok(child)
|
||||
}
|
||||
|
||||
struct ClaudeAgentConnection {
|
||||
delegate: AcpClientDelegate,
|
||||
session_id: Uuid,
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
|
||||
_mcp_server: Option<ClaudeMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
@ -350,7 +434,7 @@ impl ClaudeAgentConnection {
|
|||
incoming_tx: UnboundedSender<SdkMessage>,
|
||||
mut outgoing_bytes: impl Unpin + AsyncWrite,
|
||||
incoming_bytes: impl Unpin + AsyncRead,
|
||||
) -> Result<()> {
|
||||
) -> Result<UnboundedReceiver<SdkMessage>> {
|
||||
let mut output_reader = BufReader::new(incoming_bytes);
|
||||
let mut outgoing_line = Vec::new();
|
||||
let mut incoming_line = String::new();
|
||||
|
@ -384,7 +468,8 @@ impl ClaudeAgentConnection {
|
|||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
||||
Ok(outgoing_rx)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -507,14 +592,14 @@ enum SdkMessage {
|
|||
Assistant {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
session_id: Option<Uuid>,
|
||||
},
|
||||
|
||||
// A user message
|
||||
User {
|
||||
message: Message, // from Anthropic SDK
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
session_id: Option<Uuid>,
|
||||
},
|
||||
|
||||
// Emitted as the last message in a conversation
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue