This commit is contained in:
Antonio Scandurra 2025-08-18 13:45:55 +02:00
parent 501e72e8f0
commit 5259c8692d
6 changed files with 250 additions and 86 deletions

View file

@ -1,9 +1,9 @@
use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
use crate::{ use crate::{
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, UserMessageContent, WebSearchTool, templates::Templates,
}; };
use crate::{DbThread, ThreadsDatabase}; use crate::{DbThread, ThreadsDatabase};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
@ -461,10 +461,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
f: impl 'static f: impl 'static
+ FnOnce( + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent agent
@ -488,7 +485,10 @@ impl NativeAgentConnection {
log::trace!("Received completion event: {:?}", event); log::trace!("Received completion event: {:?}", event);
match event { match event {
AgentResponseEvent::Text(text) => { ThreadEvent::UserMessage(message) => {
todo!()
}
ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -500,7 +500,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::Thinking(text) => { ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -512,7 +512,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call, tool_call,
options, options,
response, response,
@ -535,17 +535,17 @@ impl NativeAgentConnection {
}) })
.detach(); .detach();
} }
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx) thread.upsert_tool_call(tool_call, cx)
})??; })??;
} }
AgentResponseEvent::ToolCallUpdate(update) => { ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx) thread.update_tool_call(update, cx)
})??; })??;
} }
AgentResponseEvent::Stop(stop_reason) => { ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason); log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason }); return Ok(acp::PromptResponse { stop_reason });
} }
@ -786,7 +786,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.into_iter() .into_iter()
.map(|thread| AcpThreadMetadata { .map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(), agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id, id: thread.id.into(),
title: thread.title, title: thread.title,
updated_at: thread.updated_at, updated_at: thread.updated_at,
}) })
@ -806,11 +806,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> { ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let thread_id = session_id.clone().into();
let database = self.0.update(cx, |this, _| this.thread_database.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let database = database.await.map_err(|e| anyhow!(e))?; let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database let db_thread = database
.load_thread(session_id.clone()) .load_thread(thread_id)
.await? .await?
.context("no such thread found")?; .context("no such thread found")?;

View file

@ -1,4 +1,4 @@
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
use agent::thread_store; use agent::thread_store;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata { pub struct DbThreadMetadata {
pub id: acp::SessionId, pub id: ThreadId,
#[serde(alias = "summary")] #[serde(alias = "summary")]
pub title: SharedString, pub title: SharedString,
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
@ -323,7 +323,7 @@ impl ThreadsDatabase {
for (id, summary, updated_at) in rows { for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata { threads.push(DbThreadMetadata {
id: acp::SessionId(id), id: ThreadId(id),
title: summary.into(), title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
}); });
@ -333,7 +333,7 @@ impl ThreadsDatabase {
}) })
} }
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> { pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone(); let connection = self.connection.clone();
self.executor.spawn(async move { self.executor.spawn(async move {

View file

@ -329,7 +329,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false; let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message // Look for a tool use in the thread's last message
let message = thread.last_message().unwrap(); let message = thread.last_message().unwrap();
@ -710,7 +710,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
} }
async fn expect_tool_call( async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCall { ) -> acp::ToolCall {
let event = events let event = events
.next() .next()
@ -718,7 +718,7 @@ async fn expect_tool_call(
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call, ThreadEvent::ToolCall(tool_call) => return tool_call,
event => { event => {
panic!("Unexpected event {event:?}"); panic!("Unexpected event {event:?}");
} }
@ -726,7 +726,7 @@ async fn expect_tool_call(
} }
async fn expect_tool_call_update_fields( async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate { ) -> acp::ToolCallUpdate {
let event = events let event = events
.next() .next()
@ -734,7 +734,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update; return update;
} }
event => { event => {
@ -744,7 +744,7 @@ async fn expect_tool_call_update_fields(
} }
async fn next_tool_call_authorization( async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization { ) -> ToolCallAuthorization {
loop { loop {
let event = events let event = events
@ -752,7 +752,7 @@ async fn next_tool_call_authorization(
.await .await
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization let permission_kinds = tool_call_authorization
.options .options
.iter() .iter()
@ -912,13 +912,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false; let mut echo_completed = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
match event.unwrap() { match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0)); assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" { if tool_call.title == "Echo" {
echo_id = Some(tool_call.id); echo_id = Some(tool_call.id);
} }
} }
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate { acp::ToolCallUpdate {
id, id,
fields: fields:
@ -946,7 +946,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
assert!( assert!(
matches!( matches!(
last_event, last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
), ),
"unexpected event {last_event:?}" "unexpected event {last_event:?}"
); );
@ -1386,11 +1386,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
} }
/// Filters out the stop events for asserting against in tests /// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> { fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events result_events
.into_iter() .into_iter()
.filter_map(|event| match event.unwrap() { .filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None, _ => None,
}) })
.collect() .collect()

View file

@ -1,4 +1,4 @@
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId}; use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog; use action_log::ActionLog;
use agent_client_protocol as acp; use agent_client_protocol as acp;
@ -30,10 +30,12 @@ use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock}; use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid; use uuid::Uuid;
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
#[derive( #[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)] )]
pub struct ThreadId(Arc<str>); pub struct ThreadId(pub(crate) Arc<str>);
impl ThreadId { impl ThreadId {
pub fn new() -> Self { pub fn new() -> Self {
@ -53,6 +55,18 @@ impl From<&str> for ThreadId {
} }
} }
impl From<acp::SessionId> for ThreadId {
fn from(value: acp::SessionId) -> Self {
Self(value.0)
}
}
impl From<ThreadId> for acp::SessionId {
fn from(value: ThreadId) -> Self {
Self(value.0)
}
}
/// The ID of the user prompt that initiated a request. /// The ID of the user prompt that initiated a request.
/// ///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
@ -313,9 +327,6 @@ impl AgentMessage {
AgentMessageContent::RedactedThinking(_) => { AgentMessageContent::RedactedThinking(_) => {
markdown.push_str("<redacted_thinking />\n") markdown.push_str("<redacted_thinking />\n")
} }
AgentMessageContent::Image(_) => {
markdown.push_str("<image />\n");
}
AgentMessageContent::ToolUse(tool_use) => { AgentMessageContent::ToolUse(tool_use) => {
markdown.push_str(&format!( markdown.push_str(&format!(
"**Tool Use**: {} (ID: {})\n", "**Tool Use**: {} (ID: {})\n",
@ -386,9 +397,6 @@ impl AgentMessage {
AgentMessageContent::ToolUse(value) => { AgentMessageContent::ToolUse(value) => {
language_model::MessageContent::ToolUse(value.clone()) language_model::MessageContent::ToolUse(value.clone())
} }
AgentMessageContent::Image(value) => {
language_model::MessageContent::Image(value.clone())
}
}; };
assistant_message.content.push(chunk); assistant_message.content.push(chunk);
} }
@ -432,14 +440,14 @@ pub enum AgentMessageContent {
signature: Option<String>, signature: Option<String>,
}, },
RedactedThinking(String), RedactedThinking(String),
Image(LanguageModelImage),
ToolUse(LanguageModelToolUse), ToolUse(LanguageModelToolUse),
} }
#[derive(Debug)] #[derive(Debug)]
pub enum AgentResponseEvent { pub enum ThreadEvent {
Text(String), UserMessage(UserMessage),
Thinking(String), AgentText(String),
AgentThinking(String),
ToolCall(acp::ToolCall), ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization), ToolCallAuthorization(ToolCallAuthorization),
@ -504,6 +512,121 @@ impl Thread {
} }
} }
pub fn from_db(
id: ThreadId,
db_thread: DbThread,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = db_thread
.profile
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
Self {
id,
prompt_id: PromptId::new(),
messages: db_thread.messages,
completion_mode: CompletionMode::Normal,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
context_server_registry,
profile_id,
project_context,
templates,
model,
project,
action_log,
}
}
pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
let (tx, rx) = mpsc::unbounded();
let stream = ThreadEventStream(tx);
for message in &self.messages {
match message {
Message::User(user_message) => stream.send_user_message(&user_message),
Message::Agent(assistant_message) => {
for content in &assistant_message.content {
match content {
AgentMessageContent::Text(text) => stream.send_text(text),
AgentMessageContent::Thinking { text, .. } => {
stream.send_thinking(text)
}
AgentMessageContent::RedactedThinking(_) => {}
AgentMessageContent::ToolUse(tool_use) => {
self.replay_tool_call(
tool_use,
assistant_message.tool_results.get(&tool_use.id),
&stream,
cx,
);
}
}
}
}
Message::Resume => {}
}
}
rx
}
fn replay_tool_call(
&self,
tool_use: &LanguageModelToolUse,
tool_result: Option<&LanguageModelToolResult>,
stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
stream
.0
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Failed,
content: Vec::new(),
locations: Vec::new(),
raw_input: Some(tool_use.input.clone()),
raw_output: None,
})))
.ok();
return;
};
let title = tool.initial_title(tool_use.input.clone());
let kind = tool.kind();
stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
if let Some(output) = tool_result
.as_ref()
.and_then(|result| result.output.clone())
{
let tool_event_stream = ToolCallEventStream::new(
tool_use.id.clone(),
stream.clone(),
Some(self.project.read(cx).fs().clone()),
);
tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
.log_err();
} else {
stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields {
content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
status: Some(acp::ToolCallStatus::Failed),
..Default::default()
},
);
}
}
pub fn project(&self) -> &Entity<Project> { pub fn project(&self) -> &Entity<Project> {
&self.project &self.project
} }
@ -574,7 +697,7 @@ impl Thread {
pub fn resume( pub fn resume(
&mut self, &mut self,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> { ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
anyhow::ensure!( anyhow::ensure!(
self.tool_use_limit_reached, self.tool_use_limit_reached,
"can only resume after tool use limit is reached" "can only resume after tool use limit is reached"
@ -595,7 +718,7 @@ impl Thread {
id: UserMessageId, id: UserMessageId,
content: impl IntoIterator<Item = T>, content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
where where
T: Into<UserMessageContent>, T: Into<UserMessageContent>,
{ {
@ -613,15 +736,12 @@ impl Thread {
self.run_turn(cx) self.run_turn(cx)
} }
fn run_turn( fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
&mut self,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
self.cancel(); self.cancel();
let model = self.model.clone(); let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>(); let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
let event_stream = AgentResponseEventStream(events_tx); let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1); let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false; self.tool_use_limit_reached = false;
self.running_turn = Some(RunningTurn { self.running_turn = Some(RunningTurn {
@ -755,7 +875,7 @@ impl Thread {
fn handle_streamed_completion_event( fn handle_streamed_completion_event(
&mut self, &mut self,
event: LanguageModelCompletionEvent, event: LanguageModelCompletionEvent,
event_stream: &AgentResponseEventStream, event_stream: &ThreadEventStream,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> { ) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event); log::trace!("Handling streamed completion event: {:?}", event);
@ -797,7 +917,7 @@ impl Thread {
fn handle_text_event( fn handle_text_event(
&mut self, &mut self,
new_text: String, new_text: String,
event_stream: &AgentResponseEventStream, event_stream: &ThreadEventStream,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
event_stream.send_text(&new_text); event_stream.send_text(&new_text);
@ -818,7 +938,7 @@ impl Thread {
&mut self, &mut self,
new_text: String, new_text: String,
new_signature: Option<String>, new_signature: Option<String>,
event_stream: &AgentResponseEventStream, event_stream: &ThreadEventStream,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
event_stream.send_thinking(&new_text); event_stream.send_thinking(&new_text);
@ -850,7 +970,7 @@ impl Thread {
fn handle_tool_use_event( fn handle_tool_use_event(
&mut self, &mut self,
tool_use: LanguageModelToolUse, tool_use: LanguageModelToolUse,
event_stream: &AgentResponseEventStream, event_stream: &ThreadEventStream,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> { ) -> Option<Task<LanguageModelToolResult>> {
cx.notify(); cx.notify();
@ -989,9 +1109,7 @@ impl Thread {
tool_use_id: tool_use.id.clone(), tool_use_id: tool_use.id.clone(),
tool_name: tool_use.name.clone(), tool_name: tool_use.name.clone(),
is_error: true, is_error: true,
content: LanguageModelToolResultContent::Text( content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
"Tool canceled by user".into(),
),
output: None, output: None,
}, },
); );
@ -1143,7 +1261,7 @@ struct RunningTurn {
_task: Task<()>, _task: Task<()>,
/// The current event stream for the running turn. Used to report a final /// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn. /// cancellation event if we cancel the turn.
event_stream: AgentResponseEventStream, event_stream: ThreadEventStream,
} }
impl RunningTurn { impl RunningTurn {
@ -1196,6 +1314,17 @@ where
cx: &mut App, cx: &mut App,
) -> Task<Result<Self::Output>>; ) -> Task<Result<Self::Output>>;
/// Emits events for a previous execution of the tool.
fn replay(
&self,
_input: Self::Input,
_output: Self::Output,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
fn erase(self) -> Arc<dyn AnyAgentTool> { fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self))) Arc::new(Erased(Arc::new(self)))
} }
@ -1223,6 +1352,13 @@ pub trait AnyAgentTool {
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<AgentToolOutput>>; ) -> Task<Result<AgentToolOutput>>;
fn replay(
&self,
input: serde_json::Value,
output: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()>;
} }
impl<T> AnyAgentTool for Erased<Arc<T>> impl<T> AnyAgentTool for Erased<Arc<T>>
@ -1274,21 +1410,39 @@ where
}) })
}) })
} }
fn replay(
&self,
input: serde_json::Value,
output: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
let input = serde_json::from_value(input)?;
let output = serde_json::from_value(output)?;
self.0.replay(input, output, event_stream, cx)
}
} }
#[derive(Clone)] #[derive(Clone)]
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>); struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
impl ThreadEventStream {
fn send_user_message(&self, message: &UserMessage) {
self.0
.unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
.ok();
}
impl AgentResponseEventStream {
fn send_text(&self, text: &str) { fn send_text(&self, text: &str) {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
.ok(); .ok();
} }
fn send_thinking(&self, text: &str) { fn send_thinking(&self, text: &str) {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
.ok(); .ok();
} }
@ -1300,7 +1454,7 @@ impl AgentResponseEventStream {
input: serde_json::Value, input: serde_json::Value,
) { ) {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
id, id,
title.to_string(), title.to_string(),
kind, kind,
@ -1333,7 +1487,7 @@ impl AgentResponseEventStream {
fields: acp::ToolCallUpdateFields, fields: acp::ToolCallUpdateFields,
) { ) {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp::ToolCallUpdate { acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()), id: acp::ToolCallId(tool_use_id.to_string().into()),
fields, fields,
@ -1347,17 +1501,17 @@ impl AgentResponseEventStream {
match reason { match reason {
StopReason::EndTurn => { StopReason::EndTurn => {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
.ok(); .ok();
} }
StopReason::MaxTokens => { StopReason::MaxTokens => {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
.ok(); .ok();
} }
StopReason::Refusal => { StopReason::Refusal => {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
.ok(); .ok();
} }
StopReason::ToolUse => {} StopReason::ToolUse => {}
@ -1366,7 +1520,7 @@ impl AgentResponseEventStream {
fn send_canceled(&self) { fn send_canceled(&self) {
self.0 self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
.ok(); .ok();
} }
@ -1378,24 +1532,23 @@ impl AgentResponseEventStream {
#[derive(Clone)] #[derive(Clone)]
pub struct ToolCallEventStream { pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
stream: AgentResponseEventStream, stream: ThreadEventStream,
fs: Option<Arc<dyn Fs>>, fs: Option<Arc<dyn Fs>>,
} }
impl ToolCallEventStream { impl ToolCallEventStream {
#[cfg(test)] #[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) { pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>(); let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
let stream = let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
(stream, ToolCallEventStreamReceiver(events_rx)) (stream, ToolCallEventStreamReceiver(events_rx))
} }
fn new( fn new(
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
stream: AgentResponseEventStream, stream: ThreadEventStream,
fs: Option<Arc<dyn Fs>>, fs: Option<Arc<dyn Fs>>,
) -> Self { ) -> Self {
Self { Self {
@ -1413,7 +1566,7 @@ impl ToolCallEventStream {
pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) { pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
self.stream self.stream
.0 .0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp_thread::ToolCallUpdateDiff { acp_thread::ToolCallUpdateDiff {
id: acp::ToolCallId(self.tool_use_id.to_string().into()), id: acp::ToolCallId(self.tool_use_id.to_string().into()),
diff, diff,
@ -1426,7 +1579,7 @@ impl ToolCallEventStream {
pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) { pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
self.stream self.stream
.0 .0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp_thread::ToolCallUpdateTerminal { acp_thread::ToolCallUpdateTerminal {
id: acp::ToolCallId(self.tool_use_id.to_string().into()), id: acp::ToolCallId(self.tool_use_id.to_string().into()),
terminal, terminal,
@ -1444,7 +1597,7 @@ impl ToolCallEventStream {
let (response_tx, response_rx) = oneshot::channel(); let (response_tx, response_rx) = oneshot::channel();
self.stream self.stream
.0 .0
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
ToolCallAuthorization { ToolCallAuthorization {
tool_call: acp::ToolCallUpdate { tool_call: acp::ToolCallUpdate {
id: acp::ToolCallId(self.tool_use_id.to_string().into()), id: acp::ToolCallId(self.tool_use_id.to_string().into()),
@ -1494,13 +1647,13 @@ impl ToolCallEventStream {
} }
#[cfg(test)] #[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>); pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
#[cfg(test)] #[cfg(test)]
impl ToolCallEventStreamReceiver { impl ToolCallEventStreamReceiver {
pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
let event = self.0.next().await; let event = self.0.next().await;
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
auth auth
} else { } else {
panic!("Expected ToolCallAuthorization but got: {:?}", event); panic!("Expected ToolCallAuthorization but got: {:?}", event);
@ -1509,9 +1662,9 @@ impl ToolCallEventStreamReceiver {
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> { pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await; let event = self.0.next().await;
if let Some(Ok(AgentResponseEvent::ToolCallUpdate( if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
acp_thread::ToolCallUpdate::UpdateTerminal(update), update,
))) = event )))) = event
{ {
update.terminal update.terminal
} else { } else {
@ -1522,7 +1675,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)] #[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver { impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>; type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0

View file

@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
}) })
}) })
} }
fn replay(
&self,
_input: serde_json::Value,
_output: serde_json::Value,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
} }

View file

@ -319,7 +319,7 @@ mod tests {
use theme::ThemeSettings; use theme::ThemeSettings;
use util::test::TempTree; use util::test::TempTree;
use crate::AgentResponseEvent; use crate::ThreadEvent;
use super::*; use super::*;
@ -396,7 +396,7 @@ mod tests {
}); });
cx.run_until_parked(); cx.run_until_parked();
let event = stream_rx.try_next(); let event = stream_rx.try_next();
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap(); auth.response.send(auth.options[0].id.clone()).unwrap();
} }