From 6c255c19736389916e8862f339464ae319dc9019 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 19 Aug 2025 16:24:23 +0200 Subject: [PATCH] Lay the groundwork to support history in agent2 (#36483) This pull request introduces title generation and history replaying. We still need to wire up the rest of the history but this gets us very close. I extracted a lot of this code from `agent2-history` because that branch was starting to get long-lived and there were lots of changes since we started. Release Notes: - N/A --- Cargo.lock | 3 + crates/acp_thread/src/acp_thread.rs | 39 +- crates/acp_thread/src/connection.rs | 16 +- crates/acp_thread/src/diff.rs | 13 +- crates/acp_thread/src/mention.rs | 3 +- crates/agent2/Cargo.toml | 4 + crates/agent2/src/agent.rs | 104 +-- crates/agent2/src/tests/mod.rs | 94 ++- crates/agent2/src/thread.rs | 667 ++++++++++++++---- .../src/tools/context_server_registry.rs | 10 + crates/agent2/src/tools/edit_file_tool.rs | 137 ++-- crates/agent2/src/tools/terminal_tool.rs | 4 +- crates/agent2/src/tools/web_search_tool.rs | 67 +- crates/agent_servers/Cargo.toml | 1 + crates/agent_servers/src/acp/v0.rs | 4 +- crates/agent_servers/src/acp/v1.rs | 7 +- crates/agent_servers/src/claude.rs | 12 +- crates/agent_ui/src/acp/thread_view.rs | 31 +- crates/agent_ui/src/agent_diff.rs | 41 +- 19 files changed, 929 insertions(+), 328 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d7edc54257..dc9d074f01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,6 +191,7 @@ version = "0.1.0" dependencies = [ "acp_thread", "action_log", + "agent", "agent-client-protocol", "agent_servers", "agent_settings", @@ -208,6 +209,7 @@ dependencies = [ "env_logger 0.11.8", "fs", "futures 0.3.31", + "git", "gpui", "gpui_tokio", "handlebars 4.5.0", @@ -256,6 +258,7 @@ name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "action_log", "agent-client-protocol", "agent_settings", "agentic-coding-protocol", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 227ca984d4..7d70727252 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -537,9 +537,15 @@ impl ToolCallContent { acp::ToolCallContent::Content { content } => { Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) } - acp::ToolCallContent::Diff { diff } => { - Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) - } + acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| { + Diff::finalized( + diff.path, + diff.old_text, + diff.new_text, + language_registry, + cx, + ) + })), } } @@ -682,6 +688,7 @@ pub struct AcpThread { #[derive(Debug)] pub enum AcpThreadEvent { NewEntry, + TitleUpdated, EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, @@ -728,11 +735,9 @@ impl AcpThread { title: impl Into, connection: Rc, project: Entity, + action_log: Entity, session_id: acp::SessionId, - cx: &mut Context, ) -> Self { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - Self { action_log, shared_buffers: Default::default(), @@ -926,6 +931,12 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { + self.title = title; + cx.emit(AcpThreadEvent::TitleUpdated); + Ok(()) + } + pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context) { cx.emit(AcpThreadEvent::Retry(status)); } @@ -1657,7 +1668,7 @@ mod tests { use super::*; use anyhow::anyhow; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use gpui::{App, AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::{FakeFs, Fs}; use rand::Rng as _; @@ -2327,7 +2338,7 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::App, + cx: &mut App, ) -> Task>> { let session_id = acp::SessionId( rand::thread_rng() @@ -2337,8 +2348,16 @@ mod tests { .collect::() .into(), ); - let thread = - cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|_cx| { + AcpThread::new( + "Test", + self.clone(), + project, + action_log, + session_id.clone(), + ) + }); self.sessions.lock().insert(session_id, thread.downgrade()); Task::ready(Ok(thread)) } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 0d4116321d..b09f383029 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -5,11 +5,12 @@ use collections::IndexMap; use gpui::{Entity, SharedString, Task}; use language_model::LanguageModelProviderId; use project::Project; +use serde::{Deserialize, Serialize}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessageId(Arc); impl UserMessageId { @@ -208,6 +209,7 @@ impl AgentModelList { mod test_support { use std::sync::Arc; + use action_log::ActionLog; use collections::HashMap; use futures::{channel::oneshot, future::try_join_all}; use gpui::{AppContext as _, WeakEntity}; @@ -295,8 +297,16 @@ mod test_support { cx: &mut gpui::App, ) -> Task>> { let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); - let thread = - cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|_cx| { + AcpThread::new( + "Test", + self.clone(), + project, + action_log, + session_id.clone(), + ) + }); self.sessions.lock().insert( session_id, Session { diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index e5f71d2109..4b779931c5 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -1,4 +1,3 @@ -use agent_client_protocol as acp; use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{MultiBuffer, PathKey}; @@ -21,17 +20,13 @@ pub enum Diff { } impl Diff { - pub fn from_acp( - diff: acp::Diff, + pub fn finalized( + path: PathBuf, + old_text: Option, + new_text: String, language_registry: Arc, cx: &mut Context, ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs index 25e64acbee..350785ec1e 100644 --- a/crates/acp_thread/src/mention.rs +++ b/crates/acp_thread/src/mention.rs @@ -2,6 +2,7 @@ use agent::ThreadId; use anyhow::{Context as _, Result, bail}; use file_icons::FileIcons; use prompt_store::{PromptId, UserPromptId}; +use serde::{Deserialize, Serialize}; use std::{ fmt, ops::Range, @@ -11,7 +12,7 @@ use std::{ use ui::{App, IconName, SharedString}; use url::Url; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum MentionUri { File { abs_path: PathBuf, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ac1840e5e5..8129341545 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] acp_thread.workspace = true action_log.workspace = true +agent.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true @@ -26,6 +27,7 @@ collections.workspace = true context_server.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } html_to_markdown.workspace = true @@ -59,6 +61,7 @@ which.workspace = true workspace-hack.workspace = true [dev-dependencies] +agent = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } @@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 480b2baa95..9cf0c3b603 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,10 +1,11 @@ use crate::{ - AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, - DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, - ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, + EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, + UserMessageContent, WebSearchTool, templates::Templates, }; use acp_thread::AgentModelSelector; +use action_log::ActionLog; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; @@ -427,18 +428,19 @@ impl NativeAgent { ) { self.models.refresh_list(cx); - let default_model = LanguageModelRegistry::read_global(cx) - .default_model() - .map(|m| m.model.clone()); + let registry = LanguageModelRegistry::read_global(cx); + let default_model = registry.default_model().map(|m| m.model.clone()); + let summarization_model = registry.thread_summary_model().map(|m| m.model.clone()); for session in self.sessions.values_mut() { session.thread.update(cx, |thread, cx| { if thread.model().is_none() && let Some(model) = default_model.clone() { - thread.set_model(model); + thread.set_model(model, cx); cx.notify(); } + thread.set_summarization_model(summarization_model.clone(), cx); }); } } @@ -462,10 +464,7 @@ impl NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, f: impl 'static - + FnOnce( - Entity, - &mut App, - ) -> Result>>, + + FnOnce(Entity, &mut App) -> Result>>, ) -> Task> { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { agent @@ -489,7 +488,18 @@ impl NativeAgentConnection { log::trace!("Received completion event: {:?}", event); match event { - AgentResponseEvent::Text(text) => { + ThreadEvent::UserMessage(message) => { + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; + } + ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -501,7 +511,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::Thinking(text) => { + ThreadEvent::AgentThinking(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -513,7 +523,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { tool_call, options, response, @@ -536,22 +546,26 @@ impl NativeAgentConnection { }) .detach(); } - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) })??; } - AgentResponseEvent::ToolCallUpdate(update) => { + ThreadEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { thread.update_tool_call(update, cx) })??; } - AgentResponseEvent::Retry(status) => { + ThreadEvent::TitleUpdate(title) => { + acp_thread + .update(cx, |thread, cx| thread.update_title(title, cx))??; + } + ThreadEvent::Retry(status) => { acp_thread.update(cx, |thread, cx| { thread.update_retry_status(status, cx) })?; } - AgentResponseEvent::Stop(stop_reason) => { + ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); } @@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; - thread.update(cx, |thread, _cx| { - thread.set_model(model.clone()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); }); update_settings_file::( @@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection { cx.spawn(async move |cx| { log::debug!("Starting thread creation in async context"); - // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); - log::info!("Created session with ID: {}", session_id); - - // Create AcpThread - let acp_thread = cx.update(|cx| { - cx.new(|cx| { - acp_thread::AcpThread::new( - "agent2", - self.clone(), - project.clone(), - session_id.clone(), - cx, - ) - }) - })?; - let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; - + let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?; // Create Thread let thread = agent.update( cx, |agent, cx: &mut gpui::Context| -> Result<_> { // Fetch default model from registry settings let registry = LanguageModelRegistry::read_global(cx); + let language_registry = project.read(cx).languages().clone(); // Log available models for debugging let available_count = registry.available_models(cx).count(); @@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .models .model_from_id(&LanguageModels::model_id(&default_model.model)) }); + let summarization_model = registry.thread_summary_model().map(|c| c.model); let thread = cx.new(|cx| { let mut thread = Thread::new( @@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection { action_log.clone(), agent.templates.clone(), default_model, + summarization_model, cx, ); thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry)); thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone())); @@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { thread.add_tool(MovePathTool::new(project.clone())); thread.add_tool(NowTool); thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone())); thread.add_tool(TerminalTool::new(project.clone(), cx)); thread.add_tool(ThinkingTool); thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. @@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }, )??; + let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?; + log::info!("Created session with ID: {}", session_id); + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|_cx| { + acp_thread::AcpThread::new( + "agent2", + self.clone(), + project.clone(), + action_log.clone(), + session_id.clone(), + ) + }) + })?; + // Store the session agent.update(cx, |agent, cx| { agent.sessions.insert( @@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, _cx| thread.cancel()); + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); } }); } @@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity); impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + Task::ready( + self.0 + .update(cx, |thread, cx| thread.truncate(message_id, cx)), + ) } } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index c83479f2cf..33706b05de 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { - if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + if let Ok(ThreadEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message let message = thread.last_message().unwrap(); @@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { ); } -async fn expect_tool_call( - events: &mut UnboundedReceiver>, -) -> acp::ToolCall { +async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall { let event = events .next() .await .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCall(tool_call) => return tool_call, + ThreadEvent::ToolCall(tool_call) => return tool_call, event => { panic!("Unexpected event {event:?}"); } @@ -752,7 +750,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { return update; } event => { @@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -778,7 +776,7 @@ async fn next_tool_call_authorization( .await .expect("no tool call authorization event received") .unwrap(); - if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event { let permission_kinds = tool_call_authorization .options .iter() @@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut echo_completed = false; while let Some(event) = events.next().await { match event.unwrap() { - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { assert_eq!(tool_call.title, expected_tools.remove(0)); if tool_call.title == "Echo" { echo_id = Some(tool_call.id); } } - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( acp::ToolCallUpdate { id, fields: @@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, _cx| thread.cancel()); + thread.update(cx, |thread, cx| thread.cancel(cx)); let events = events.collect::>().await; let last_event = events.last(); assert!( matches!( last_event, - Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) ), "unexpected event {last_event:?}" ); @@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, _cx| thread.truncate(message_id)) + .update(cx, |thread, cx| thread.truncate(message_id, cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { @@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_title_generation(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_summarization_model(Some(summary_model.clone()), cx) + }); + + let send = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread")); + + // Ensure the summary model has been invoked to generate a title. + summary_model.send_last_completion_stream_text_chunk("Hello "); + summary_model.send_last_completion_stream_text_chunk("world\nG"); + summary_model.send_last_completion_stream_text_chunk("oodnight Moon"); + summary_model.end_last_completion_stream(); + send.collect::>().await; + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world")); + + // Send another message, ensuring no title is generated this time. + let send = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello again"], cx) + }) + .unwrap(); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey again!"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + assert_eq!(summary_model.pending_completions(), Vec::new()); + send.collect::>().await; + thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world")); +} + #[gpui::test] async fn test_agent_connection(cx: &mut TestAppContext) { cx.update(settings::init); @@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.send(UserMessageId::new(), ["Hello!"], cx) }) .unwrap(); @@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) { let mut retry_events = Vec::new(); while let Some(Ok(event)) = events.next().await { match event { - AgentResponseEvent::Retry(retry_status) => { + ThreadEvent::Retry(retry_status) => { retry_events.push(retry_status); } - AgentResponseEvent::Stop(..) => break, + ThreadEvent::Stop(..) => break, _ => {} } } @@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.send(UserMessageId::new(), ["Hello!"], cx) }) .unwrap(); @@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { let mut retry_events = Vec::new(); while let Some(Ok(event)) = events.next().await { match event { - AgentResponseEvent::Retry(retry_status) => { + ThreadEvent::Retry(retry_status) => { retry_events.push(retry_status); } - AgentResponseEvent::Stop(..) => break, + ThreadEvent::Stop(..) => break, _ => {} } } @@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { - thread.set_completion_mode(agent_settings::CompletionMode::Burn); + thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.send(UserMessageId::new(), ["Hello!"], cx) }) .unwrap(); @@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { let mut retry_events = Vec::new(); while let Some(event) = events.next().await { match event { - Ok(AgentResponseEvent::Retry(retry_status)) => { + Ok(ThreadEvent::Retry(retry_status)) => { retry_events.push(retry_status); } - Ok(AgentResponseEvent::Stop(..)) => break, + Ok(ThreadEvent::Stop(..)) => break, Err(error) => errors.push(error), _ => {} } @@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events(result_events: Vec>) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { - AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + ThreadEvent::Stop(stop_reason) => Some(stop_reason), _ => None, }) .collect() @@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { action_log, templates, Some(model.clone()), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 856e70ce59..aeb600e232 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,25 +1,34 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; +use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; +use chrono::{DateTime, Utc}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; use collections::IndexMap; use fs::Fs; use futures::{ + FutureExt, channel::{mpsc, oneshot}, + future::Shared, stream::FuturesUnordered, }; +use git::repository::DiffType; use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + TokenUsage, +}; +use project::{ + Project, + git_store::{GitStore, RepositoryState}, }; -use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; @@ -35,28 +44,7 @@ use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; /// The ID of the user prompt that initiated a request. /// @@ -91,7 +79,7 @@ enum RetryStrategy { }, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Message { User(UserMessage), Agent(AgentMessage), @@ -106,6 +94,18 @@ impl Message { } } + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], + } + } + pub fn to_markdown(&self) -> String { match self { Message::User(message) => message.to_markdown(), @@ -113,15 +113,22 @@ impl Message { Message::Resume => "[resumed after tool use limit was reached]".into(), } } + + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, + } + } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessage { pub id: UserMessageId, pub content: Vec, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum UserMessageContent { Text(String), Mention { uri: MentionUri, content: String }, @@ -345,9 +352,6 @@ impl AgentMessage { AgentMessageContent::RedactedThinking(_) => { markdown.push_str("\n") } - AgentMessageContent::Image(_) => { - markdown.push_str("\n"); - } AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -418,9 +422,6 @@ impl AgentMessage { AgentMessageContent::ToolUse(value) => { language_model::MessageContent::ToolUse(value.clone()) } - AgentMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } }; assistant_message.content.push(chunk); } @@ -450,13 +451,13 @@ impl AgentMessage { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentMessage { pub content: Vec, pub tool_results: IndexMap, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum AgentMessageContent { Text(String), Thinking { @@ -464,17 +465,18 @@ pub enum AgentMessageContent { signature: Option, }, RedactedThinking(String), - Image(LanguageModelImage), ToolUse(LanguageModelToolUse), } #[derive(Debug)] -pub enum AgentResponseEvent { - Text(String), - Thinking(String), +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + TitleUpdate(SharedString), Retry(acp_thread::RetryStatus), Stop(acp::StopReason), } @@ -487,8 +489,12 @@ pub struct ToolCallAuthorization { } pub struct Thread { - id: ThreadId, + id: acp::SessionId, prompt_id: PromptId, + updated_at: DateTime, + title: Option, + #[allow(unused)] + summary: DetailedSummaryState, messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. @@ -498,11 +504,18 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, + #[allow(unused)] + request_token_usage: Vec, + #[allow(unused)] + cumulative_token_usage: TokenUsage, + #[allow(unused)] + initial_project_snapshot: Shared>>>, context_server_registry: Entity, profile_id: AgentProfileId, project_context: Entity, templates: Arc, model: Option>, + summarization_model: Option>, project: Entity, action_log: Entity, } @@ -515,36 +528,254 @@ impl Thread { action_log: Entity, templates: Arc, model: Option>, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { - id: ThreadId::new(), + id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()), prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: None, + summary: DetailedSummaryState::default(), messages: Vec::new(), completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: Vec::new(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: { + let project_snapshot = Self::project_snapshot(project.clone(), cx); + cx.foreground_executor() + .spawn(async move { Some(project_snapshot.await) }) + .shared() + }, context_server_registry, profile_id, project_context, templates, model, + summarization_model, project, action_log, } } - pub fn project(&self) -> &Entity { - &self.project + pub fn id(&self) -> &acp::SessionId { + &self.id + } + + pub fn replay( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } + } + rx + } + + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; + + let title = tool.initial_title(tool_use.input.clone()); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + let output = tool_result + .as_ref() + .and_then(|result| result.output.clone()); + if let Some(output) = output.clone() { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + raw_output: output, + ..Default::default() + }, + ); + } + + /// Create a snapshot of the current project state including git information and unsaved buffers. + fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_, cx| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + let mut unsaved_buffers = Vec::new(); + cx.update(|app_cx| { + let buffer_store = project.read(app_cx).buffer_store(); + for buffer_handle in buffer_store.read(app_cx).buffers() { + let buffer = buffer_handle.read(app_cx); + if buffer.is_dirty() + && let Some(file) = buffer.file() + { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); + } + } + }) + .ok(); + + Arc::new(ProjectSnapshot { + worktree_snapshots, + unsaved_buffer_paths: unsaved_buffers, + timestamp: Utc::now(), + }) + }) + } + + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().to_string(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + WorktreeSnapshot { + worktree_path, + git_state, + } + }) } pub fn project_context(&self) -> &Entity { &self.project_context } + pub fn project(&self) -> &Entity { + &self.project + } + pub fn action_log(&self) -> &Entity { &self.action_log } @@ -553,16 +784,27 @@ impl Thread { self.model.as_ref() } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = Some(model); + cx.notify() + } + + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } - pub fn set_completion_mode(&mut self, mode: CompletionMode) { + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { self.completion_mode = mode; + cx.notify() } #[cfg(any(test, feature = "test-support"))] @@ -590,29 +832,29 @@ impl Thread { self.profile_id = profile_id; } - pub fn cancel(&mut self) { + pub fn cancel(&mut self, cx: &mut Context) { if let Some(running_turn) = self.running_turn.take() { running_turn.cancel(); } - self.flush_pending_message(); + self.flush_pending_message(cx); } - pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { - self.cancel(); + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { return Err(anyhow!("Message not found")); }; self.messages.truncate(position); + cx.notify(); Ok(()) } pub fn resume( &mut self, cx: &mut Context, - ) -> Result>> { - anyhow::ensure!(self.model.is_some(), "Model not set"); + ) -> Result>> { anyhow::ensure!( self.tool_use_limit_reached, "can only resume after tool use limit is reached" @@ -633,7 +875,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> Result>> + ) -> Result>> where T: Into, { @@ -656,22 +898,19 @@ impl Thread { fn run_turn( &mut self, cx: &mut Context, - ) -> Result>> { - self.cancel(); + ) -> Result>> { + self.cancel(cx); - let model = self - .model() - .cloned() - .context("No language model configured")?; - let (events_tx, events_rx) = mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); + let model = self.model.clone().context("No language model configured")?; + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; self.running_turn = Some(RunningTurn { event_stream: event_stream.clone(), _task: cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); - let turn_result: Result<()> = async { + let turn_result: Result = async { let mut completion_intent = CompletionIntent::UserPrompt; loop { log::debug!( @@ -685,18 +924,27 @@ impl Thread { log::info!("Calling model.stream_completion"); let mut tool_use_limit_reached = false; + let mut refused = false; + let mut reached_max_tokens = false; let mut tool_uses = Self::stream_completion_with_retries( this.clone(), model.clone(), request, - message_ix, &event_stream, &mut tool_use_limit_reached, + &mut refused, + &mut reached_max_tokens, cx, ) .await?; - let used_tools = tool_uses.is_empty(); + if refused { + return Ok(StopReason::Refusal); + } else if reached_max_tokens { + return Ok(StopReason::MaxTokens); + } + + let end_turn = tool_uses.is_empty(); while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); @@ -724,29 +972,42 @@ impl Thread { log::info!("Tool use limit reached, completing turn"); this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; return Err(language_model::ToolUseLimitReachedError.into()); - } else if used_tools { + } else if end_turn { log::info!("No tool uses found, completing turn"); - return Ok(()); + return Ok(StopReason::EndTurn); } else { - this.update(cx, |this, _| this.flush_pending_message())?; + this.update(cx, |this, cx| this.flush_pending_message(cx))?; completion_intent = CompletionIntent::ToolResults; } } } .await; + _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); - if let Err(error) = turn_result { - log::error!("Turn execution failed: {:?}", error); - event_stream.send_error(error); - } else { - log::info!("Turn execution completed successfully"); + match turn_result { + Ok(reason) => { + log::info!("Turn execution completed: {:?}", reason); + + let update_title = this + .update(cx, |this, cx| this.update_title(&event_stream, cx)) + .ok() + .flatten(); + if let Some(update_title) = update_title { + update_title.await.context("update title failed").log_err(); + } + + event_stream.send_stop(reason); + if reason == StopReason::Refusal { + _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); + } + } + Err(error) => { + log::error!("Turn execution failed: {:?}", error); + event_stream.send_error(error); + } } - this.update(cx, |this, _| { - this.flush_pending_message(); - this.running_turn.take(); - }) - .ok(); + _ = this.update(cx, |this, _| this.running_turn.take()); }), }); Ok(events_rx) @@ -756,9 +1017,10 @@ impl Thread { this: WeakEntity, model: Arc, request: LanguageModelRequest, - message_ix: usize, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, tool_use_limit_reached: &mut bool, + refusal: &mut bool, + max_tokens_reached: &mut bool, cx: &mut AsyncApp, ) -> Result>> { log::debug!("Stream completion started successfully"); @@ -774,16 +1036,17 @@ impl Thread { )) => { *tool_use_limit_reached = true; } - Ok(LanguageModelCompletionEvent::Stop(reason)) => { - event_stream.send_stop(reason); - if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); - this.messages.truncate(message_ix); - })?; - return Ok(tool_uses); - } + Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => { + *refusal = true; + return Ok(FuturesUnordered::default()); } + Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => { + *max_tokens_reached = true; + return Ok(FuturesUnordered::default()); + } + Ok(LanguageModelCompletionEvent::Stop( + StopReason::ToolUse | StopReason::EndTurn, + )) => break, Ok(event) => { log::trace!("Received completion event: {:?}", event); this.update(cx, |this, cx| { @@ -843,6 +1106,7 @@ impl Thread { } } } + return Ok(tool_uses); } } @@ -870,7 +1134,7 @@ impl Thread { fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { log::trace!("Handling streamed completion event: {:?}", event); @@ -878,7 +1142,7 @@ impl Thread { match event { StartMessage { .. } => { - self.flush_pending_message(); + self.flush_pending_message(cx); self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), @@ -912,7 +1176,7 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_text(&new_text); @@ -933,7 +1197,7 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_thinking(&new_text); @@ -965,7 +1229,7 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { cx.notify(); @@ -1083,11 +1347,85 @@ impl Thread { } } + pub fn title(&self) -> SharedString { + self.title.clone().unwrap_or("New Thread".into()) + } + + fn update_title( + &mut self, + event_stream: &ThreadEventStream, + cx: &mut Context, + ) -> Option>> { + if self.title.is_some() { + log::debug!("Skipping title generation because we already have one."); + return None; + } + + log::info!( + "Generating title with model: {:?}", + self.summarization_model.as_ref().map(|model| model.name()) + ); + let model = self.summarization_model.clone()?; + let event_stream = event_stream.clone(); + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![SUMMARIZE_THREAD_PROMPT.into()], + cache: false, + }); + Some(cx.spawn(async move |this, cx| { + let mut title = String::new(); + let mut messages = model.stream_completion(request, cx).await?; + while let Some(event) = messages.next().await { + let event = event?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { .. }, + ) => { + // this.update(cx, |thread, cx| { + // thread.update_model_request_usage(amount as u32, limit, cx); + // })?; + // TODO: handle usage update + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + title.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + + log::info!("Setting title: {}", title); + + this.update(cx, |this, cx| { + let title = SharedString::from(title); + event_stream.send_title_update(title.clone()); + this.title = Some(title); + cx.notify(); + }) + })) + } + fn pending_message(&mut self) -> &mut AgentMessage { self.pending_message.get_or_insert_default() } - fn flush_pending_message(&mut self) { + fn flush_pending_message(&mut self, cx: &mut Context) { let Some(mut message) = self.pending_message.take() else { return; }; @@ -1104,9 +1442,7 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text( - "Tool canceled by user".into(), - ), + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), output: None, }, ); @@ -1114,6 +1450,8 @@ impl Thread { } self.messages.push(Message::Agent(message)); + self.updated_at = Utc::now(); + cx.notify() } pub(crate) fn build_completion_request( @@ -1205,15 +1543,7 @@ impl Thread { ); let mut messages = vec![self.build_system_message(cx)]; for message in &self.messages { - match message { - Message::User(message) => messages.push(message.to_request()), - Message::Agent(message) => messages.extend(message.to_request()), - Message::Resume => messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec!["Continue where you left off".into()], - cache: false, - }), - } + messages.extend(message.to_request()); } if let Some(message) = self.pending_message.as_ref() { @@ -1367,7 +1697,7 @@ struct RunningTurn { _task: Task<()>, /// The current event stream for the running turn. Used to report a final /// cancellation event if we cancel the turn. - event_stream: AgentResponseEventStream, + event_stream: ThreadEventStream, } impl RunningTurn { @@ -1420,6 +1750,17 @@ where cx: &mut App, ) -> Task>; + /// Emits events for a previous execution of the tool. + fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } @@ -1447,6 +1788,13 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; } impl AnyAgentTool for Erased> @@ -1498,21 +1846,45 @@ where }) }) } + + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) + } } #[derive(Clone)] -struct AgentResponseEventStream(mpsc::UnboundedSender>); +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_title_update(&self, text: SharedString) { + self.0 + .unbounded_send(Ok(ThreadEvent::TitleUpdate(text))) + .ok(); + } + + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); + } -impl AgentResponseEventStream { fn send_text(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) .ok(); } fn send_thinking(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) .ok(); } @@ -1524,7 +1896,7 @@ impl AgentResponseEventStream { input: serde_json::Value, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, title.to_string(), kind, @@ -1557,7 +1929,7 @@ impl AgentResponseEventStream { fields: acp::ToolCallUpdateFields, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.to_string().into()), fields, @@ -1568,26 +1940,24 @@ impl AgentResponseEventStream { } fn send_retry(&self, status: acp_thread::RetryStatus) { - self.0 - .unbounded_send(Ok(AgentResponseEvent::Retry(status))) - .ok(); + self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); } fn send_stop(&self, reason: StopReason) { match reason { StopReason::EndTurn => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn))) .ok(); } StopReason::MaxTokens => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens))) .ok(); } StopReason::Refusal => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal))) .ok(); } StopReason::ToolUse => {} @@ -1596,7 +1966,7 @@ impl AgentResponseEventStream { fn send_canceled(&self) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) .ok(); } @@ -1608,24 +1978,23 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = - ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, ) -> Self { Self { @@ -1643,7 +2012,7 @@ impl ToolCallEventStream { pub fn update_diff(&self, diff: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateDiff { id: acp::ToolCallId(self.tool_use_id.to_string().into()), diff, @@ -1656,7 +2025,7 @@ impl ToolCallEventStream { pub fn update_terminal(&self, terminal: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateTerminal { id: acp::ToolCallId(self.tool_use_id.to_string().into()), terminal, @@ -1674,7 +2043,7 @@ impl ToolCallEventStream { let (response_tx, response_rx) = oneshot::channel(); self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( ToolCallAuthorization { tool_call: acp::ToolCallUpdate { id: acp::ToolCallId(self.tool_use_id.to_string().into()), @@ -1724,13 +2093,13 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { auth } else { panic!("Expected ToolCallAuthorization but got: {:?}", event); @@ -1739,9 +2108,9 @@ impl ToolCallEventStreamReceiver { pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdate::UpdateTerminal(update), - ))) = event + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event { update.terminal } else { @@ -1752,7 +2121,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 @@ -1821,6 +2190,38 @@ impl From for UserMessageContent { } } +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + uri: None, + }), + UserMessageContent::Mention { uri, content } => { + acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: uri.to_uri().to_string(), + name: uri.name(), + annotations: None, + description: if content.is_empty() { + None + } else { + Some(content) + }, + mime_type: None, + size: None, + title: None, + }) + } + } + } +} + fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { LanguageModelImage { source: image_content.data.into(), diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs index ddeb08a046..69c4221a81 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool { }) }) } + + fn replay( + &self, + _input: serde_json::Value, + _output: serde_json::Value, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 7687d68702..b3b1a428bf 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; -use language::ToPoint; use language::language_settings::{self, FormatOnSave}; +use language::{LanguageRegistry, ToPoint}; use language_model::LanguageModelToolResultContent; use paths; use project::lsp_store::{FormatTrigger, LspFormatTarget}; @@ -98,11 +98,13 @@ pub enum EditFileMode { #[derive(Debug, Serialize, Deserialize)] pub struct EditFileToolOutput { + #[serde(alias = "original_path")] input_path: PathBuf, - project_path: PathBuf, new_text: String, old_text: Arc, + #[serde(default)] diff: String, + #[serde(alias = "raw_output")] edit_agent_output: EditAgentOutput, } @@ -122,12 +124,16 @@ impl From for LanguageModelToolResultContent { } pub struct EditFileTool { - thread: Entity, + thread: WeakEntity, + language_registry: Arc, } impl EditFileTool { - pub fn new(thread: Entity) -> Self { - Self { thread } + pub fn new(thread: WeakEntity, language_registry: Arc) -> Self { + Self { + thread, + language_registry, + } } fn authorize( @@ -167,8 +173,11 @@ impl EditFileTool { // Check if path is inside the global config directory // First check if it's already inside project - if not, try to canonicalize - let thread = self.thread.read(cx); - let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + let Ok(project_path) = self.thread.read_with(cx, |thread, cx| { + thread.project().read(cx).find_project_path(&input.path, cx) + }) else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; // If the path is inside the project, and it's not one of the above edge cases, // then no confirmation is necessary. Otherwise, confirmation is necessary. @@ -221,7 +230,12 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let project = self.thread.read(cx).project().clone(); + let Ok(project) = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), @@ -237,23 +251,17 @@ impl AgentTool for EditFileTool { }); } - let Some(request) = self.thread.update(cx, |thread, cx| { - thread - .build_completion_request(CompletionIntent::ToolResults, cx) - .ok() - }) else { - return Task::ready(Err(anyhow!("Failed to build completion request"))); - }; - let thread = self.thread.read(cx); - let Some(model) = thread.model().cloned() else { - return Task::ready(Err(anyhow!("No language model configured"))); - }; - let action_log = thread.action_log().clone(); - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { authorize.await?; + let (request, model, action_log) = self.thread.update(cx, |thread, cx| { + let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); + (request, thread.model().cloned(), thread.action_log().clone()) + })?; + let request = request?; + let model = model.context("No language model configured")?; + let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( model, @@ -419,7 +427,6 @@ impl AgentTool for EditFileTool { Ok(EditFileToolOutput { input_path: input.path, - project_path: project_path.path.to_path_buf(), new_text: new_text.clone(), old_text, diff: unified_diff, @@ -427,6 +434,25 @@ impl AgentTool for EditFileTool { }) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + output.input_path, + Some(output.old_text.to_string()), + output.new_text, + self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } } /// Validate that the file path is valid, meaning: @@ -515,6 +541,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -527,6 +554,7 @@ mod tests { action_log, Templates::new(), Some(model), + None, cx, ) }); @@ -537,7 +565,11 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -724,6 +756,7 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); @@ -750,9 +783,10 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -806,7 +840,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -850,6 +888,7 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { @@ -860,6 +899,7 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); @@ -887,9 +927,10 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -938,10 +979,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -976,6 +1018,7 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { @@ -986,10 +1029,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1111,6 +1155,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1123,10 +1168,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1220,7 +1266,7 @@ mod tests { cx, ) .await; - + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1233,10 +1279,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees let test_cases = vec![ @@ -1302,6 +1349,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1314,10 +1362,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases let test_cases = vec![ @@ -1386,6 +1435,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1398,10 +1448,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values let modes = vec![ @@ -1467,6 +1518,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1479,10 +1531,11 @@ mod tests { action_log.clone(), Templates::new(), Some(model.clone()), + None, cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( tool.initial_title(Err(json!({ diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs index ac79874c36..1804d0ab30 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -319,7 +319,7 @@ mod tests { use theme::ThemeSettings; use util::test::TempTree; - use crate::AgentResponseEvent; + use crate::ThreadEvent; use super::*; @@ -396,7 +396,7 @@ mod tests { }); cx.run_until_parked(); let event = stream_rx.try_next(); - if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { + if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event { auth.response.send(auth.options[0].id.clone()).unwrap(); } diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index c1c0970742..d71a128bfe 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool { } }; - let result_text = if response.results.len() == 1 { - "1 result".to_string() - } else { - format!("{} results", response.results.len()) - }; - event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(format!("Searched the web: {result_text}")), - content: Some( - response - .results - .iter() - .map(|result| acp::ToolCallContent::Content { - content: acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: result.title.clone(), - uri: result.url.clone(), - title: Some(result.title.clone()), - description: Some(result.text.clone()), - mime_type: None, - annotations: None, - size: None, - }), - }) - .collect(), - ), - ..Default::default() - }); + emit_update(&response, &event_stream); Ok(WebSearchToolOutput(response)) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + emit_update(&output.0, &event_stream); + Ok(()) + } +} + +fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) { + let result_text = if response.results.len() == 1 { + "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 886f650470..cbc874057a 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -18,6 +18,7 @@ doctest = false [dependencies] acp_thread.workspace = true +action_log.workspace = true agent-client-protocol.workspace = true agent_settings.workspace = true agentic-coding-protocol.workspace = true diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 551e9fa01a..aa80f01c15 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -1,4 +1,5 @@ // Translates old acp agents into the new schema +use action_log::ActionLog; use agent_client_protocol as acp; use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; use anyhow::{Context as _, Result, anyhow}; @@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection { cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.name, self.clone(), project, session_id, cx) + let action_log = cx.new(|_| ActionLog::new(project.clone())); + AcpThread::new(self.name, self.clone(), project, action_log, session_id) }); current_thread.replace(thread.downgrade()); thread diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index 93a5ae757a..d749537c4c 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -1,3 +1,4 @@ +use action_log::ActionLog; use agent_client_protocol::{self as acp, Agent as _}; use anyhow::anyhow; use collections::HashMap; @@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection { })?; let session_id = response.session_id; - - let thread = cx.new(|cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone()))?; + let thread = cx.new(|_cx| { AcpThread::new( self.server_name, self.clone(), project, + action_log, session_id.clone(), - cx, ) })?; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 34d55f39dc..f27c973ad6 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -1,6 +1,7 @@ mod mcp_server; pub mod tools; +use action_log::ActionLog; use collections::HashMap; use context_server::listener::McpServerTool; use language_models::provider::anthropic::AnthropicLanguageModelProvider; @@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection { } }); - let thread = cx.new(|cx| { - AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) + let action_log = cx.new(|_| ActionLog::new(project.clone()))?; + let thread = cx.new(|_cx| { + AcpThread::new( + "Claude Code", + self.clone(), + project, + action_log, + session_id.clone(), + ) })?; thread_tx.send(thread.downgrade())?; diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index ad0920bc4a..150f1ea73b 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -303,8 +303,13 @@ impl AcpThreadView { let action_log_subscription = cx.observe(&action_log, |_, _, cx| cx.notify()); - this.list_state - .splice(0..0, thread.read(cx).entries().len()); + let count = thread.read(cx).entries().len(); + this.list_state.splice(0..0, count); + this.entry_view_state.update(cx, |view_state, cx| { + for ix in 0..count { + view_state.sync_entry(ix, &thread, window, cx); + } + }); AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); @@ -808,6 +813,7 @@ impl AcpThreadView { self.thread_retry_status.take(); self.thread_state = ThreadState::ServerExited { status: *status }; } + AcpThreadEvent::TitleUpdated => {} } cx.notify(); } @@ -2816,12 +2822,15 @@ impl AcpThreadView { return; }; - thread.update(cx, |thread, _cx| { + thread.update(cx, |thread, cx| { let current_mode = thread.completion_mode(); - thread.set_completion_mode(match current_mode { - CompletionMode::Burn => CompletionMode::Normal, - CompletionMode::Normal => CompletionMode::Burn, - }); + thread.set_completion_mode( + match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }, + cx, + ); }); } @@ -3572,8 +3581,9 @@ impl AcpThreadView { )) .on_click({ cx.listener(move |this, _, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.set_completion_mode(CompletionMode::Burn); + thread.update(cx, |thread, cx| { + thread + .set_completion_mode(CompletionMode::Burn, cx); }); this.resume_chat(cx); }) @@ -4156,12 +4166,13 @@ pub(crate) mod tests { cx: &mut gpui::App, ) -> Task>> { Task::ready(Ok(cx.new(|cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); AcpThread::new( "SaboteurAgentConnection", self, project, + action_log, SessionId("test".into()), - cx, ) }))) } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index f474fdf3ae..5b4f1038e2 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -199,24 +199,21 @@ impl AgentDiffPane { let action_log = thread.action_log(cx).clone(); let mut this = Self { - _subscriptions: [ - Some( - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - ), + _subscriptions: vec![ + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), match &thread { - AgentDiffThread::Native(thread) => { - Some(cx.subscribe(thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - })) - } - AgentDiffThread::AcpThread(_) => None, + AgentDiffThread::Native(thread) => cx + .subscribe(thread, |this, _thread, event, cx| { + this.handle_native_thread_event(event, cx) + }), + AgentDiffThread::AcpThread(thread) => cx + .subscribe(thread, |this, _thread, event, cx| { + this.handle_acp_thread_event(event, cx) + }), }, - ] - .into_iter() - .flatten() - .collect(), + ], title: SharedString::default(), multibuffer, editor, @@ -324,13 +321,20 @@ impl AgentDiffPane { } } - fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { + fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { match event { ThreadEvent::SummaryGenerated => self.update_title(cx), _ => {} } } + fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context) { + match event { + AcpThreadEvent::TitleUpdated => self.update_title(cx), + _ => {} + } + } + pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { self.editor.update(cx, |editor, cx| { @@ -1523,7 +1527,8 @@ impl AgentDiff { AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => { self.update_reviewing_editors(workspace, window, cx); } - AcpThreadEvent::EntriesRemoved(_) + AcpThreadEvent::TitleUpdated + | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Retry(_) => {} }