agent_servers: Use built-in interrupt handling for Claude sessions (#35154)
We no longer have to stop and restart the entire process. I left in the Start/Resume mode handling since we will likely need to handle restarting Claude in other situations. Release Notes: - N/A
This commit is contained in:
parent
89e88c245e
commit
a5b7cfd128
3 changed files with 58 additions and 78 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -168,6 +168,7 @@ dependencies = [
|
||||||
"nix 0.29.0",
|
"nix 0.29.0",
|
||||||
"paths",
|
"paths",
|
||||||
"project",
|
"project",
|
||||||
|
"rand 0.8.5",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -29,6 +29,7 @@ itertools.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
paths.workspace = true
|
paths.workspace = true
|
||||||
project.workspace = true
|
project.workspace = true
|
||||||
|
rand.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|
|
@ -9,7 +9,6 @@ use smol::process::Child;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::pin::pin;
|
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -66,19 +65,6 @@ impl AgentServer for ClaudeCode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
|
||||||
fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> {
|
|
||||||
let pid = nix::unistd::Pid::from_raw(pid);
|
|
||||||
|
|
||||||
nix::sys::signal::kill(pid, nix::sys::signal::SIGINT)
|
|
||||||
.map_err(|e| anyhow!("Failed to interrupt process: {}", e))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(windows)]
|
|
||||||
fn send_interrupt(_pid: i32) -> anyhow::Result<()> {
|
|
||||||
panic!("Cancel not implemented on Windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ClaudeAgentConnection {
|
struct ClaudeAgentConnection {
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
|
||||||
}
|
}
|
||||||
|
@ -127,7 +113,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
|
||||||
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
|
||||||
let (cancel_tx, mut cancel_rx) = mpsc::unbounded::<oneshot::Sender<Result<()>>>();
|
|
||||||
|
|
||||||
let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
|
let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
|
||||||
|
|
||||||
|
@ -137,50 +122,28 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
let session_id = session_id.clone();
|
let session_id = session_id.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut outgoing_rx = Some(outgoing_rx);
|
let mut outgoing_rx = Some(outgoing_rx);
|
||||||
let mut mode = ClaudeSessionMode::Start;
|
|
||||||
|
|
||||||
loop {
|
let mut child = spawn_claude(
|
||||||
let mut child = spawn_claude(
|
&command,
|
||||||
&command,
|
ClaudeSessionMode::Start,
|
||||||
mode,
|
session_id.clone(),
|
||||||
session_id.clone(),
|
&mcp_config_path,
|
||||||
&mcp_config_path,
|
&cwd,
|
||||||
&cwd,
|
)
|
||||||
)
|
.await?;
|
||||||
.await?;
|
|
||||||
mode = ClaudeSessionMode::Resume;
|
|
||||||
|
|
||||||
let pid = child.id();
|
let pid = child.id();
|
||||||
log::trace!("Spawned (pid: {})", pid);
|
log::trace!("Spawned (pid: {})", pid);
|
||||||
|
|
||||||
let mut io_fut = pin!(
|
ClaudeAgentSession::handle_io(
|
||||||
ClaudeAgentSession::handle_io(
|
outgoing_rx.take().unwrap(),
|
||||||
outgoing_rx.take().unwrap(),
|
incoming_message_tx.clone(),
|
||||||
incoming_message_tx.clone(),
|
child.stdin.take().unwrap(),
|
||||||
child.stdin.take().unwrap(),
|
child.stdout.take().unwrap(),
|
||||||
child.stdout.take().unwrap(),
|
)
|
||||||
)
|
.await?;
|
||||||
.fuse()
|
|
||||||
);
|
|
||||||
|
|
||||||
select_biased! {
|
log::trace!("Stopped (pid: {})", pid);
|
||||||
done_tx = cancel_rx.next() => {
|
|
||||||
if let Some(done_tx) = done_tx {
|
|
||||||
log::trace!("Interrupted (pid: {})", pid);
|
|
||||||
let result = send_interrupt(pid as i32);
|
|
||||||
outgoing_rx.replace(io_fut.await?);
|
|
||||||
done_tx.send(result).log_err();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result = io_fut => {
|
|
||||||
result?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log::trace!("Stopped (pid: {})", pid);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(mcp_config_path);
|
drop(mcp_config_path);
|
||||||
anyhow::Ok(())
|
anyhow::Ok(())
|
||||||
|
@ -213,7 +176,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
let session = ClaudeAgentSession {
|
let session = ClaudeAgentSession {
|
||||||
outgoing_tx,
|
outgoing_tx,
|
||||||
end_turn_tx,
|
end_turn_tx,
|
||||||
cancel_tx,
|
|
||||||
_handler_task: handler_task,
|
_handler_task: handler_task,
|
||||||
_mcp_server: Some(permission_mcp_server),
|
_mcp_server: Some(permission_mcp_server),
|
||||||
};
|
};
|
||||||
|
@ -278,37 +240,24 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||||
let sessions = self.sessions.borrow();
|
let sessions = self.sessions.borrow();
|
||||||
let Some(session) = sessions.get(&session_id) else {
|
let Some(session) = sessions.get(&session_id) else {
|
||||||
log::warn!("Attempted to cancel nonexistent session {}", session_id);
|
log::warn!("Attempted to cancel nonexistent session {}", session_id);
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
let (done_tx, done_rx) = oneshot::channel();
|
session
|
||||||
if session
|
.outgoing_tx
|
||||||
.cancel_tx
|
.unbounded_send(SdkMessage::new_interrupt_message())
|
||||||
.unbounded_send(done_tx)
|
.log_err();
|
||||||
.log_err()
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
let end_turn_tx = session.end_turn_tx.clone();
|
|
||||||
cx.foreground_executor()
|
|
||||||
.spawn(async move {
|
|
||||||
done_rx.await??;
|
|
||||||
if let Some(end_turn_tx) = end_turn_tx.take() {
|
|
||||||
end_turn_tx.send(Ok(())).ok();
|
|
||||||
}
|
|
||||||
anyhow::Ok(())
|
|
||||||
})
|
|
||||||
.detach_and_log_err(cx);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
enum ClaudeSessionMode {
|
enum ClaudeSessionMode {
|
||||||
Start,
|
Start,
|
||||||
|
#[expect(dead_code)]
|
||||||
Resume,
|
Resume,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,7 +313,6 @@ async fn spawn_claude(
|
||||||
struct ClaudeAgentSession {
|
struct ClaudeAgentSession {
|
||||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<()>>>>>,
|
||||||
cancel_tx: UnboundedSender<oneshot::Sender<Result<()>>>,
|
|
||||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||||
_handler_task: Task<()>,
|
_handler_task: Task<()>,
|
||||||
}
|
}
|
||||||
|
@ -377,6 +325,8 @@ impl ClaudeAgentSession {
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) {
|
) {
|
||||||
match message {
|
match message {
|
||||||
|
// we should only be sending these out, they don't need to be in the thread
|
||||||
|
SdkMessage::ControlRequest { .. } => {}
|
||||||
SdkMessage::Assistant {
|
SdkMessage::Assistant {
|
||||||
message,
|
message,
|
||||||
session_id: _,
|
session_id: _,
|
||||||
|
@ -643,14 +593,12 @@ enum SdkMessage {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
// A user message
|
// A user message
|
||||||
User {
|
User {
|
||||||
message: Message, // from Anthropic SDK
|
message: Message, // from Anthropic SDK
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Emitted as the last message in a conversation
|
// Emitted as the last message in a conversation
|
||||||
Result {
|
Result {
|
||||||
subtype: ResultErrorType,
|
subtype: ResultErrorType,
|
||||||
|
@ -675,6 +623,18 @@ enum SdkMessage {
|
||||||
#[serde(rename = "permissionMode")]
|
#[serde(rename = "permissionMode")]
|
||||||
permission_mode: PermissionMode,
|
permission_mode: PermissionMode,
|
||||||
},
|
},
|
||||||
|
/// Messages used to control the conversation, outside of chat messages to the model
|
||||||
|
ControlRequest {
|
||||||
|
request_id: String,
|
||||||
|
request: ControlRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "subtype", rename_all = "snake_case")]
|
||||||
|
enum ControlRequest {
|
||||||
|
/// Cancel the current conversation
|
||||||
|
Interrupt,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
@ -695,6 +655,24 @@ impl Display for ResultErrorType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SdkMessage {
|
||||||
|
fn new_interrupt_message() -> Self {
|
||||||
|
use rand::Rng;
|
||||||
|
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||||
|
// `Math.random().toString(36).substring(2, 15)`
|
||||||
|
let request_id = rand::thread_rng()
|
||||||
|
.sample_iter(&rand::distributions::Alphanumeric)
|
||||||
|
.take(12)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self::ControlRequest {
|
||||||
|
request_id,
|
||||||
|
request: ControlRequest::Interrupt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct McpServer {
|
struct McpServer {
|
||||||
name: String,
|
name: String,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue