diff --git a/Cargo.lock b/Cargo.lock index 3158a61ad8..f5ea0e1f8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "agent-client-protocol", "anyhow", "buffer_diff", + "chrono", "collections", "editor", "env_logger 0.11.8", @@ -191,10 +192,12 @@ version = "0.1.0" dependencies = [ "acp_thread", "action_log", + "agent", "agent-client-protocol", "agent_servers", "agent_settings", "anyhow", + "assistant_context", "assistant_tool", "assistant_tools", "chrono", @@ -208,6 +211,7 @@ dependencies = [ "env_logger 0.11.8", "fs", "futures 0.3.31", + "git", "gpui", "gpui_tokio", "handlebars 4.5.0", @@ -221,6 +225,7 @@ dependencies = [ "log", "lsp", "open", + "parking_lot", "paths", "portable-pty", "pretty_assertions", @@ -233,6 +238,7 @@ dependencies = [ "serde_json", "settings", "smol", + "sqlez", "task", "tempfile", "terminal", @@ -249,6 +255,7 @@ dependencies = [ "workspace-hack", "worktree", "zlog", + "zstd", ] [[package]] diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2b9a6513c8..cbe74c1f37 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -21,6 +21,7 @@ agent-client-protocol.workspace = true agent.workspace = true anyhow.workspace = true buffer_diff.workspace = true +chrono.workspace = true collections.workspace = true editor.workspace = true file_icons.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index e104c40bf2..bb9c2e35ea 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -6,11 +6,13 @@ mod terminal; pub use connection::*; pub use diff::*; pub use mention::*; +use serde::{Deserialize, Serialize}; pub use terminal::*; use action_log::ActionLog; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; +use chrono::{DateTime, Utc}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; @@ -537,9 +539,15 @@ impl ToolCallContent { acp::ToolCallContent::Content { content } => { Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) } - acp::ToolCallContent::Diff { diff } => { - Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) - } + acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| { + Diff::finalized( + diff.path, + diff.old_text, + diff.new_text, + language_registry, + cx, + ) + })), } } @@ -658,6 +666,17 @@ impl PlanEntry { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct AgentServerName(pub SharedString); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcpThreadMetadata { + pub agent: AgentServerName, + pub id: acp::SessionId, + pub title: SharedString, + pub updated_at: DateTime, +} + pub struct AcpThread { title: SharedString, entries: Vec, @@ -673,6 +692,7 @@ pub struct AcpThread { #[derive(Debug)] pub enum AcpThreadEvent { NewEntry, + TitleUpdated, EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, @@ -916,6 +936,12 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { + self.title = title; + cx.emit(AcpThreadEvent::TitleUpdated); + Ok(()) + } + pub fn update_tool_call( &mut self, update: impl Into, @@ -1641,7 +1667,7 @@ mod tests { use super::*; use anyhow::anyhow; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use gpui::{App, AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::{FakeFs, Fs}; use rand::Rng as _; @@ -2311,7 +2337,7 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::App, + cx: &mut App, ) -> Task>> { let session_id = acp::SessionId( rand::thread_rng() diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index a328499bbc..af653a1c74 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,14 +1,16 @@ -use crate::AcpThread; +use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; +use futures::channel::mpsc::UnboundedReceiver; use gpui::{Entity, SharedString, Task}; use project::Project; +use serde::{Deserialize, Serialize}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct UserMessageId(Arc); impl UserMessageId { @@ -62,6 +64,10 @@ pub trait AgentConnection { None } + fn history(self: Rc) -> Option> { + None + } + fn into_any(self: Rc) -> Rc; } @@ -79,6 +85,18 @@ pub trait AgentSessionResume { fn run(&self, cx: &mut App) -> Task>; } +pub trait AgentHistory { + fn list_threads(&self, cx: &mut App) -> Task>>; + fn observe_history(&self, cx: &mut App) -> UnboundedReceiver; + fn load_thread( + self: Rc, + _project: Entity, + _cwd: &Path, + _session_id: acp::SessionId, + _cx: &mut App, + ) -> Task>>; +} + #[derive(Debug)] pub struct AuthRequired; diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index a2c2d6c322..a67e37bcb8 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -1,4 +1,3 @@ -use agent_client_protocol as acp; use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{MultiBuffer, PathKey}; @@ -21,17 +20,13 @@ pub enum Diff { } impl Diff { - pub fn from_acp( - diff: acp::Diff, + pub fn finalized( + path: PathBuf, + old_text: Option, + new_text: String, language_registry: Arc, cx: &mut Context, ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs index b9b021c4ca..6686e9ebf1 100644 --- a/crates/acp_thread/src/mention.rs +++ b/crates/acp_thread/src/mention.rs @@ -2,6 +2,7 @@ use agent::ThreadId; use anyhow::{Context as _, Result, bail}; use file_icons::FileIcons; use prompt_store::{PromptId, UserPromptId}; +use serde::{Deserialize, Serialize}; use std::{ fmt, ops::Range, @@ -11,7 +12,7 @@ use std::{ use ui::{App, IconName, SharedString}; use url::Url; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum MentionUri { File { abs_path: PathBuf, diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index eb39c3e454..4f2668384f 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -62,7 +62,7 @@ enum SerializedRecentOpen { pub struct HistoryStore { thread_store: Entity, - context_store: Entity, + pub context_store: Entity, recently_opened_entries: VecDeque, _subscriptions: Vec, _save_recently_opened_entries_task: Task<()>, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 12c94a522d..e24a5ec782 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -893,7 +893,7 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = if *ZED_STATELESS { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { Connection::open_memory(Some("THREAD_FALLBACK_DB")) } else { Connection::open_file(&sqlite_path.to_string_lossy()) diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ac1840e5e5..a32b4fe939 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -17,6 +17,7 @@ action_log.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true +agent.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true @@ -26,6 +27,7 @@ collections.workspace = true context_server.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } html_to_markdown.workspace = true @@ -37,6 +39,7 @@ language_model.workspace = true language_models.workspace = true log.workspace = true open.workspace = true +parking_lot.workspace = true paths.workspace = true portable-pty.workspace = true project.workspace = true @@ -46,6 +49,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sqlez.workspace = true smol.workspace = true task.workspace = true terminal.workspace = true @@ -57,8 +61,12 @@ watch.workspace = true web_search.workspace = true which.workspace = true workspace-hack.workspace = true +zstd.workspace = true +assistant_context.workspace = true [dev-dependencies] +agent = { workspace = true, "features" = ["test-support"] } +acp_thread = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } @@ -66,6 +74,7 @@ context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index af740d9901..cc3a40f652 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,10 +1,12 @@ +use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::{ - AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, - DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, - ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, + EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, + UserMessageContent, WebSearchTool, templates::Templates, }; -use acp_thread::AgentModelSelector; +use crate::{ThreadsDatabase, generate_session_id}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentHistory, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; @@ -15,7 +17,7 @@ use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, @@ -27,6 +29,7 @@ use std::collections::HashMap; use std::path::Path; use std::rc::Rc; use std::sync::Arc; +use std::time::Duration; use util::ResultExt; const RULES_FILE_NAMES: [&'static str; 9] = [ @@ -41,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [ "GEMINI.md", ]; +const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500); + pub struct RulesLoadingError { pub message: SharedString, } @@ -51,7 +56,8 @@ struct Session { thread: Entity, /// The ACP thread that handles protocol communication acp_thread: WeakEntity, - _subscription: Subscription, + save_task: Task<()>, + _subscriptions: Vec, } pub struct LanguageModels { @@ -166,6 +172,8 @@ pub struct NativeAgent { models: LanguageModels, project: Entity, prompt_store: Option>, + thread_database: Arc, + history_watchers: Vec>, fs: Arc, _subscriptions: Vec, } @@ -184,6 +192,11 @@ impl NativeAgent { .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? .await; + let thread_database = cx + .update(|cx| ThreadsDatabase::connect(cx))? + .await + .map_err(|e| anyhow!(e))?; + cx.new(|cx| { let mut subscriptions = vec![ cx.subscribe(&project, Self::handle_project_event), @@ -208,16 +221,87 @@ impl NativeAgent { context_server_registry: cx.new(|cx| { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), + thread_database, templates, models: LanguageModels::new(cx), project, prompt_store, fs, + history_watchers: Vec::new(), _subscriptions: subscriptions, } }) } + pub fn insert_session( + &mut self, + thread: Entity, + acp_thread: Entity, + cx: &mut Context, + ) { + let id = thread.read(cx).id().clone(); + let weak_thread = acp_thread.downgrade(); + self.sessions.insert( + id, + Session { + thread: thread.clone(), + acp_thread: weak_thread.clone(), + save_task: Task::ready(()), + _subscriptions: vec![ + cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + cx.observe(&thread, move |this, thread, cx| { + if let Some(response_stream) = + thread.update(cx, |thread, cx| thread.generate_title_if_needed(cx)) + { + NativeAgentConnection::handle_thread_events( + response_stream, + weak_thread.clone(), + cx, + ) + .detach_and_log_err(cx); + } + this.save_thread(thread.clone(), cx) + }), + ], + }, + ); + } + + fn save_thread(&mut self, thread_handle: Entity, cx: &mut Context) { + let thread = thread_handle.read(cx); + let id = thread.id().clone(); + let Some(session) = self.sessions.get_mut(&id) else { + return; + }; + + let thread = thread_handle.downgrade(); + let thread_database = self.thread_database.clone(); + session.save_task = cx.spawn(async move |this, cx| { + cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + + let Some(task) = thread.update(cx, |thread, cx| thread.to_db(cx)).ok() else { + return; + }; + let db_thread = task.await; + let metadata = thread_database + .save_thread(id.clone(), db_thread) + .await + .log_err(); + if let Some(metadata) = metadata { + this.update(cx, |this, _| { + for watcher in this.history_watchers.iter_mut() { + watcher + .unbounded_send(metadata.clone().to_acp(NATIVE_AGENT_SERVER_NAME)) + .log_err(); + } + }) + .ok(); + } + }); + } + pub fn models(&self) -> &LanguageModels { &self.models } @@ -420,7 +504,7 @@ impl NativeAgent { fn handle_models_updated_event( &mut self, - _registry: Entity, + registry: Entity, _event: &language_model::Event, cx: &mut Context, ) { @@ -435,9 +519,14 @@ impl NativeAgent { if thread.model().is_none() && let Some(model) = default_model.clone() { - thread.set_model(model); + thread.set_model(model, cx); cx.notify(); } + let summarization_model = registry + .read(cx) + .thread_summary_model() + .map(|model| model.model.clone()); + thread.set_summarization_model(summarization_model, cx); }); } } @@ -461,10 +550,7 @@ impl NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, f: impl 'static - + FnOnce( - Entity, - &mut App, - ) -> Result>>, + + FnOnce(Entity, &mut App) -> Result>>, ) -> Task> { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { agent @@ -476,10 +562,18 @@ impl NativeAgentConnection { }; log::debug!("Found session for: {}", session_id); - let mut response_stream = match f(thread, cx) { + let response_stream = match f(thread, cx) { Ok(stream) => stream, Err(err) => return Task::ready(Err(err)), }; + Self::handle_thread_events(response_stream, acp_thread, cx) + } + + fn handle_thread_events( + mut response_stream: mpsc::UnboundedReceiver>, + acp_thread: WeakEntity, + cx: &mut App, + ) -> Task> { cx.spawn(async move |cx| { // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { @@ -488,7 +582,18 @@ impl NativeAgentConnection { log::trace!("Received completion event: {:?}", event); match event { - AgentResponseEvent::Text(text) => { + ThreadEvent::UserMessage(message) => { + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; + } + ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -500,7 +605,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::Thinking(text) => { + ThreadEvent::AgentThinking(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -512,7 +617,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { tool_call, options, response, @@ -535,17 +640,21 @@ impl NativeAgentConnection { }) .detach(); } - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) })??; } - AgentResponseEvent::ToolCallUpdate(update) => { + ThreadEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { thread.update_tool_call(update, cx) })??; } - AgentResponseEvent::Stop(stop_reason) => { + ThreadEvent::TitleUpdate(title) => { + acp_thread + .update(cx, |thread, cx| thread.update_title(title, cx))??; + } + ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); } @@ -564,6 +673,31 @@ impl NativeAgentConnection { }) }) } + + fn register_tools( + thread: &mut Thread, + project: Entity, + action_log: Entity, + cx: &mut Context, + ) { + let language_registry = project.read(cx).languages().clone(); + thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(CreateDirectoryTool::new(project.clone())); + thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); + thread.add_tool(DiagnosticsTool::new(project.clone())); + thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry)); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(GrepTool::new(project.clone())); + thread.add_tool(ListDirectoryTool::new(project.clone())); + thread.add_tool(MovePathTool::new(project.clone())); + thread.add_tool(NowTool); + thread.add_tool(OpenTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(TerminalTool::new(project.clone(), cx)); + thread.add_tool(ThinkingTool); + thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + } } impl AgentModelSelector for NativeAgentConnection { @@ -598,8 +732,8 @@ impl AgentModelSelector for NativeAgentConnection { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; - thread.update(cx, |thread, _cx| { - thread.set_model(model.clone()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); }); update_settings_file::( @@ -660,7 +794,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Starting thread creation in async context"); // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + let session_id = generate_session_id(); log::info!("Created session with ID: {}", session_id); // Create AcpThread @@ -694,32 +828,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .model_from_id(&LanguageModels::model_id(&default_model.model)) }); + let summarization_model = registry.thread_summary_model().map(|c| c.model); + let thread = cx.new(|cx| { let mut thread = Thread::new( + session_id.clone(), project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), action_log.clone(), agent.templates.clone(), default_model, + summarization_model, cx, ); - thread.add_tool(CopyPathTool::new(project.clone())); - thread.add_tool(CreateDirectoryTool::new(project.clone())); - thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); - thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); - thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(GrepTool::new(project.clone())); - thread.add_tool(ListDirectoryTool::new(project.clone())); - thread.add_tool(MovePathTool::new(project.clone())); - thread.add_tool(NowTool); - thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); - thread.add_tool(TerminalTool::new(project.clone(), cx)); - thread.add_tool(ThinkingTool); - thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + Self::register_tools(&mut thread, project, action_log, cx); thread }); @@ -729,16 +852,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread, - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); + agent.insert_session(thread, acp_thread.clone(), cx) })?; Ok(acp_thread) @@ -797,7 +911,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, _cx| thread.cancel()); + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); } }); } @@ -815,6 +929,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }) } + fn history(self: Rc) -> Option> { + Some(self) + } + fn into_any(self: Rc) -> Rc { self } @@ -824,7 +942,121 @@ struct NativeAgentSessionEditor(Entity); impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + Task::ready( + self.0 + .update(cx, |thread, cx| thread.truncate(message_id, cx)), + ) + } +} + +impl acp_thread::AgentHistory for NativeAgentConnection { + fn list_threads(&self, cx: &mut App) -> Task>> { + let database = self.0.read(cx).thread_database.clone(); + cx.background_executor().spawn(async move { + let threads = database.list_threads().await?; + anyhow::Ok( + threads + .into_iter() + .map(|thread| thread.to_acp(NATIVE_AGENT_SERVER_NAME)) + .collect::>(), + ) + }) + } + + fn observe_history(&self, cx: &mut App) -> mpsc::UnboundedReceiver { + let (tx, rx) = mpsc::unbounded(); + self.0.update(cx, |this, _| this.history_watchers.push(tx)); + rx + } + + fn load_thread( + self: Rc, + project: Entity, + _cwd: &Path, + session_id: acp::SessionId, + cx: &mut App, + ) -> Task>> { + let database = self.0.update(cx, |this, _| this.thread_database.clone()); + cx.spawn(async move |cx| { + let db_thread = database + .load_thread(session_id.clone()) + .await? + .context("no such thread found")?; + + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new( + db_thread.title.clone(), + self.clone(), + project.clone(), + session_id.clone(), + cx, + ) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + let agent = self.0.clone(); + + // Create Thread + let thread = agent.update(cx, |agent, cx| { + let language_model_registry = LanguageModelRegistry::global(cx); + let configured_model = language_model_registry + .update(cx, |registry, cx| { + db_thread + .model + .as_ref() + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.clone().into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + }) + .context("no default model configured")?; + + let model = agent + .models + .model_from_id(&LanguageModels::model_id(&configured_model.model)) + .context("no model by id")?; + + let summarization_model = language_model_registry + .read(cx) + .thread_summary_model() + .map(|c| c.model); + + let thread = cx.new(|cx| { + let mut thread = Thread::from_db( + session_id, + db_thread, + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + model, + summarization_model, + cx, + ); + Self::register_tools(&mut thread, project, action_log, cx); + thread + }); + + anyhow::Ok(thread) + })??; + + // Store the session + agent.update(cx, |agent, cx| { + agent.insert_session(thread.clone(), acp_thread.clone(), cx) + })?; + + let events = thread.update(cx, |thread, cx| thread.replay(cx))?; + cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))? + .await?; + + Ok(acp_thread) + }) } } @@ -844,12 +1076,16 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { #[cfg(test)] mod tests { + use crate::HistoryStore; + use super::*; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; + use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use util::path; #[gpui::test] async fn test_maintaining_project_context(cx: &mut TestAppContext) { @@ -1024,6 +1260,80 @@ mod tests { ); } + #[gpui::test] + async fn test_history(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + let history = connection.clone().history().unwrap(); + let history_store = cx.new(|cx| HistoryStore::get_or_init(cx)); + + history_store + .update(cx, |history_store, cx| { + history_store.load_history(NATIVE_AGENT_SERVER_NAME.clone(), history.as_ref(), cx) + }) + .await + .unwrap(); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(path!("")), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let selector = connection.model_selector().unwrap(); + + let summarization_model: Arc = + Arc::new(FakeLanguageModel::default()) as _; + + agent.update(cx, |agent, cx| { + let thread = agent.sessions.get(&session_id).unwrap().thread.clone(); + thread.update(cx, |thread, cx| { + thread.set_summarization_model(Some(summarization_model.clone()), cx); + }) + }); + + let model = cx + .update(|cx| selector.selected_model(&session_id, cx)) + .await + .expect("selected_model should succeed"); + let model = cx + .update(|cx| agent.read(cx).models().model_from_id(&model.id)) + .unwrap(); + let model = model.as_fake(); + + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("Hey"); + model.end_last_completion_stream(); + send.await.unwrap(); + + summarization_model + .as_fake() + .send_last_completion_stream_text_chunk("Saying Hello"); + summarization_model.as_fake().end_last_completion_stream(); + cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); + + let history = history_store.update(cx, |store, cx| store.entries(cx)); + assert_eq!(history.len(), 1); + assert_eq!(history[0].title(), "Saying Hello"); + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index f13cd1bd67..6d1d266ada 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,4 +1,6 @@ mod agent; +mod db; +mod history_store; mod native_agent_server; mod templates; mod thread; @@ -8,7 +10,15 @@ mod tools; mod tests; pub use agent::*; +pub use db::*; +pub use history_store::*; pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; pub use tools::*; + +use agent_client_protocol as acp; + +pub fn generate_session_id() -> acp::SessionId { + acp::SessionId(uuid::Uuid::new_v4().to_string().into()) +} diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs new file mode 100644 index 0000000000..afc4fdcb3f --- /dev/null +++ b/crates/agent2/src/db.rs @@ -0,0 +1,488 @@ +use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use acp_thread::{AcpThreadMetadata, AgentServerName}; +use agent::thread_store; +use agent_client_protocol as acp; +use agent_settings::{AgentProfileId, CompletionMode}; +use anyhow::{Result, anyhow}; +use chrono::{DateTime, Utc}; +use collections::{HashMap, IndexMap}; +use futures::{FutureExt, future::Shared}; +use gpui::{BackgroundExecutor, Global, Task}; +use indoc::indoc; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use sqlez::{ + bindable::{Bind, Column}, + connection::Connection, + statement::Statement, +}; +use std::sync::Arc; +use ui::{App, SharedString}; + +pub type DbMessage = crate::Message; +pub type DbSummary = agent::thread::DetailedSummaryState; +pub type DbLanguageModel = thread_store::SerializedLanguageModel; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbThreadMetadata { + pub id: acp::SessionId, + #[serde(alias = "summary")] + pub title: SharedString, + pub updated_at: DateTime, +} + +impl DbThreadMetadata { + pub fn to_acp(self, agent: AgentServerName) -> AcpThreadMetadata { + AcpThreadMetadata { + agent, + id: self.id, + title: self.title, + updated_at: self.updated_at, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DbThread { + pub title: SharedString, + pub messages: Vec, + pub updated_at: DateTime, + #[serde(default)] + pub summary: DbSummary, + #[serde(default)] + pub initial_project_snapshot: Option>, + #[serde(default)] + pub cumulative_token_usage: language_model::TokenUsage, + #[serde(default)] + pub request_token_usage: Vec, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub completion_mode: Option, + #[serde(default)] + pub profile: Option, +} + +impl DbThread { + pub const VERSION: &'static str = "0.3.0"; + + pub fn from_json(json: &[u8]) -> Result { + let saved_thread_json = serde_json::from_slice::(json)?; + match saved_thread_json.get("version") { + Some(serde_json::Value::String(version)) => match version.as_str() { + Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?), + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + }, + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + } + } + + fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result { + let mut messages = Vec::new(); + for msg in thread.messages { + let message = match msg.role { + language_model::Role::User => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { text, .. } => { + // User messages don't have thinking segments, but handle gracefully + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::RedactedThinking { .. } => { + // User messages don't have redacted thinking, skip. + } + } + } + + // If no content was added, add context as text if available + if content.is_empty() && !msg.context.is_empty() { + content.push(UserMessageContent::Text(msg.context)); + } + + crate::Message::User(UserMessage { + // MessageId from old format can't be meaningfully converted, so generate a new one + id: acp_thread::UserMessageId::new(), + content, + }) + } + language_model::Role::Assistant => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(AgentMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { + text, + signature, + } => { + content.push(AgentMessageContent::Thinking { text, signature }); + } + thread_store::SerializedMessageSegment::RedactedThinking { data } => { + content.push(AgentMessageContent::RedactedThinking(data)); + } + } + } + + // Convert tool uses + let mut tool_names_by_id = HashMap::default(); + for tool_use in msg.tool_uses { + tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone()); + content.push(AgentMessageContent::ToolUse( + language_model::LanguageModelToolUse { + id: tool_use.id, + name: tool_use.name.into(), + raw_input: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + input: tool_use.input, + is_input_complete: true, + }, + )); + } + + // Convert tool results + let mut tool_results = IndexMap::default(); + for tool_result in msg.tool_results { + let name = tool_names_by_id + .remove(&tool_result.tool_use_id) + .unwrap_or_else(|| SharedString::from("unknown")); + tool_results.insert( + tool_result.tool_use_id.clone(), + language_model::LanguageModelToolResult { + tool_use_id: tool_result.tool_use_id, + tool_name: name.into(), + is_error: tool_result.is_error, + content: tool_result.content, + output: tool_result.output, + }, + ); + } + + crate::Message::Agent(AgentMessage { + content, + tool_results, + }) + } + language_model::Role::System => { + // Skip system messages as they're not supported in the new format + continue; + } + }; + + messages.push(message); + } + + Ok(Self { + title: thread.summary, + messages, + updated_at: thread.updated_at, + summary: thread.detailed_summary_state, + initial_project_snapshot: thread.initial_project_snapshot, + cumulative_token_usage: thread.cumulative_token_usage, + request_token_usage: thread.request_token_usage, + model: thread.model, + completion_mode: thread.completion_mode, + profile: thread.profile, + }) + } +} + +pub static ZED_STATELESS: std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DataType { + #[serde(rename = "json")] + Json, + #[serde(rename = "zstd")] + Zstd, +} + +impl Bind for DataType { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let value = match self { + DataType::Json => "json", + DataType::Zstd => "zstd", + }; + value.bind(statement, start_index) + } +} + +impl Column for DataType { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (value, next_index) = String::column(statement, start_index)?; + let data_type = match value.as_str() { + "json" => DataType::Json, + "zstd" => DataType::Zstd, + _ => anyhow::bail!("Unknown data type: {}", value), + }; + Ok((data_type, next_index)) + } +} + +pub(crate) struct ThreadsDatabase { + executor: BackgroundExecutor, + connection: Arc>, +} + +struct GlobalThreadsDatabase(Shared, Arc>>>); + +impl Global for GlobalThreadsDatabase {} + +impl ThreadsDatabase { + fn connection(&self) -> Arc> { + self.connection.clone() + } + + const COMPRESSION_LEVEL: i32 = 3; +} + +impl ThreadsDatabase { + pub fn connect(cx: &mut App) -> Shared, Arc>>> { + if cx.has_global::() { + return cx.global::().0.clone(); + } + let executor = cx.background_executor().clone(); + let task = executor + .spawn({ + let executor = executor.clone(); + async move { + match ThreadsDatabase::new(executor) { + Ok(db) => Ok(Arc::new(db)), + Err(err) => Err(Arc::new(err)), + } + } + }) + .shared(); + + cx.set_global(GlobalThreadsDatabase(task.clone())); + task + } + + pub fn new(executor: BackgroundExecutor) -> Result { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + let threads_dir = paths::data_dir().join("threads"); + std::fs::create_dir_all(&threads_dir)?; + let sqlite_path = threads_dir.join("threads.db"); + Connection::open_file(&sqlite_path.to_string_lossy()) + }; + + connection.exec(indoc! {" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor: executor.clone(), + connection: Arc::new(Mutex::new(connection)), + }; + + Ok(db) + } + + fn save_thread_sync( + connection: &Arc>, + id: acp::SessionId, + thread: DbThread, + ) -> Result { + let json_data = serde_json::to_string(&thread)?; + let title = thread.title.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + + let connection = connection.lock(); + + let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + "})?; + + insert((id.0.clone(), title, updated_at, data_type, data))?; + + Ok(DbThreadMetadata { + id, + title: thread.title, + updated_at: thread.updated_at, + }) + } + + pub fn list_threads(&self) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = + connection.select_bound::<(), (Arc, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; + let mut threads = Vec::new(); + + for (id, summary, updated_at) in rows { + threads.push(DbThreadMetadata { + id: acp::SessionId(id), + title: summary.into(), + updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), + }); + } + + Ok(threads) + }) + } + + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = connection.select_bound::, (DataType, Vec)>(indoc! {" + SELECT data_type, data FROM threads WHERE id = ? LIMIT 1 + "})?; + + let rows = select(id.0)?; + if let Some((data_type, data)) = rows.into_iter().next() { + let json_data = match data_type { + DataType::Zstd => { + let decompressed = zstd::decode_all(&data[..])?; + String::from_utf8(decompressed)? + } + DataType::Json => String::from_utf8(data)?, + }; + dbg!(&json_data); + + let thread = dbg!(DbThread::from_json(json_data.as_bytes()))?; + Ok(Some(thread)) + } else { + Ok(None) + } + }) + } + + pub fn save_thread( + &self, + id: acp::SessionId, + thread: DbThread, + ) -> Task> { + let connection = self.connection.clone(); + + self.executor + .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) + } + + pub fn delete_thread(&self, id: acp::SessionId) -> Task> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + + let mut delete = connection.exec_bound::>(indoc! {" + DELETE FROM threads WHERE id = ? + "})?; + + delete(id.0)?; + + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use agent::MessageSegment; + use agent::context::LoadedContext; + use client::Client; + use fs::FakeFs; + use gpui::AppContext; + use gpui::TestAppContext; + use http_client::FakeHttpClient; + use language_model::Role; + use project::Project; + use settings::SettingsStore; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + agent::init(cx); + agent_settings::init(cx); + language_model::init(client.clone(), cx); + }); + } + + #[gpui::test] + async fn test_retrieving_old_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + // Save a thread using the old agent. + { + let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx)); + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + thread.update(cx, |thread, cx| { + thread.insert_message( + Role::User, + vec![MessageSegment::Text("Hey!".into())], + LoadedContext::default(), + vec![], + false, + cx, + ); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text("How're you doing?".into())], + LoadedContext::default(), + vec![], + false, + cx, + ) + }); + thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) + .await + .unwrap(); + } + + let db = cx.update(|cx| ThreadsDatabase::connect(cx)).await.unwrap(); + let threads = db.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + let thread = db + .load_thread(threads[0].id.clone()) + .await + .unwrap() + .unwrap(); + assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n"); + assert_eq!( + thread.messages[1].to_markdown(), + "## Assistant\n\nHow're you doing?\n" + ); + } +} diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs new file mode 100644 index 0000000000..996702bff7 --- /dev/null +++ b/crates/agent2/src/history_store.rs @@ -0,0 +1,174 @@ +use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; +use agent_client_protocol as acp; +use agent_servers::AgentServer; +use assistant_context::SavedContextMetadata; +use chrono::{DateTime, Utc}; +use collections::HashMap; +use gpui::{Entity, Global, SharedString, Task, prelude::*}; +use project::Project; +use serde::{Deserialize, Serialize}; +use ui::App; + +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; + +use crate::NativeAgentServer; + +const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; +const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; +const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50); + +// todo!(put this in the UI) +#[derive(Clone, Debug)] +pub enum HistoryEntry { + AcpThread(AcpThreadMetadata), + TextThread(SavedContextMetadata), +} + +impl HistoryEntry { + pub fn updated_at(&self) -> DateTime { + match self { + HistoryEntry::AcpThread(thread) => thread.updated_at, + HistoryEntry::TextThread(context) => context.mtime.to_utc(), + } + } + + pub fn id(&self) -> HistoryEntryId { + match self { + HistoryEntry::AcpThread(thread) => { + HistoryEntryId::Thread(thread.agent.clone(), thread.id.clone()) + } + HistoryEntry::TextThread(context) => HistoryEntryId::Context(context.path.clone()), + } + } + + pub fn title(&self) -> &SharedString { + match self { + HistoryEntry::AcpThread(thread) => &thread.title, + HistoryEntry::TextThread(context) => &context.title, + } + } +} + +/// Generic identifier for a history entry. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum HistoryEntryId { + Thread(AgentServerName, acp::SessionId), + Context(Arc), +} + +#[derive(Serialize, Deserialize)] +enum SerializedRecentOpen { + Thread(String), + ContextName(String), + /// Old format which stores the full path + Context(String), +} + +#[derive(Default)] +pub struct AgentHistory { + entries: HashMap, + loaded: bool, +} + +pub struct HistoryStore { + agents: HashMap, // todo!() text threads +} +// note, we have to share the history store between all windows +// because we only get updates from one connection at a time. +struct GlobalHistoryStore(Entity); +impl Global for GlobalHistoryStore {} + +impl HistoryStore { + pub fn get_or_init(project: &Entity, cx: &mut App) -> Entity { + if cx.has_global::() { + return cx.global::().0.clone(); + } + let history_store = cx.new(|cx| HistoryStore::new(cx)); + cx.set_global(GlobalHistoryStore(history_store.clone())); + let root_dir = project + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()) + .unwrap_or_else(|| paths::home_dir().as_path().into()); + + let agent = NativeAgentServer::new(project.read(cx).fs().clone()); + let connect = agent.connect(&root_dir, project, cx); + cx.spawn({ + let history_store = history_store.clone(); + async move |cx| { + let connection = connect.await?.history().unwrap(); + history_store + .update(cx, |history_store, cx| { + history_store.load_history(agent.name(), connection.as_ref(), cx) + })? + .await + } + }) + .detach_and_log_err(cx); + history_store + } + + fn new(_cx: &mut Context) -> Self { + Self { + agents: HashMap::default(), + } + } + + pub fn update_history(&mut self, entry: AcpThreadMetadata, cx: &mut Context) { + let agent = self + .agents + .entry(entry.agent.clone()) + .or_insert(Default::default()); + + agent.entries.insert(entry.id.clone(), entry); + cx.notify() + } + + pub fn load_history( + &mut self, + agent_name: AgentServerName, + connection: &dyn acp_thread::AgentHistory, + cx: &mut Context, + ) -> Task> { + let threads = connection.list_threads(cx); + cx.spawn(async move |this, cx| { + let threads = threads.await?; + + this.update(cx, |this, cx| { + this.agents.insert( + agent_name, + AgentHistory { + loaded: true, + entries: threads.into_iter().map(|t| (t.id.clone(), t)).collect(), + }, + ); + cx.notify() + }) + }) + } + + pub fn entries(&mut self, _cx: &mut Context) -> Vec { + let mut history_entries = Vec::new(); + + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() { + return history_entries; + } + + history_entries.extend( + self.agents + .values_mut() + .flat_map(|history| history.entries.values().cloned()) // todo!("surface the loading state?") + .map(HistoryEntry::AcpThread), + ); + // todo!() include the text threads in here. + + history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at())); + history_entries + } + + pub fn recent_entries(&mut self, limit: usize, cx: &mut Context) -> Vec { + self.entries(cx).into_iter().take(limit).collect() + } +} diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index cadd88a846..c8ff38e893 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,11 +1,13 @@ use std::{path::Path, rc::Rc, sync::Arc}; +use acp_thread::AgentServerName; use agent_servers::AgentServer; use anyhow::Result; use fs::Fs; use gpui::{App, Entity, Task}; use project::Project; use prompt_store::PromptStore; +use ui::SharedString; use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; @@ -20,9 +22,12 @@ impl NativeAgentServer { } } +pub const NATIVE_AGENT_SERVER_NAME: AgentServerName = + AgentServerName(SharedString::new_static("Native Agent")); + impl AgentServer for NativeAgentServer { - fn name(&self) -> &'static str { - "Native Agent" + fn name(&self) -> AgentServerName { + NATIVE_AGENT_SERVER_NAME.clone() } fn empty_state_headline(&self) -> &'static str { diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index e3e3050d49..9aac27dcd6 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -343,7 +343,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { - if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + if let Ok(ThreadEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message let message = thread.last_message().unwrap(); @@ -733,16 +733,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { ); } -async fn expect_tool_call( - events: &mut UnboundedReceiver>, -) -> acp::ToolCall { +async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall { let event = events .next() .await .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCall(tool_call) => return tool_call, + ThreadEvent::ToolCall(tool_call) => return tool_call, event => { panic!("Unexpected event {event:?}"); } @@ -750,7 +748,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -758,7 +756,7 @@ async fn expect_tool_call_update_fields( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { return update; } event => { @@ -768,7 +766,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -776,7 +774,7 @@ async fn next_tool_call_authorization( .await .expect("no tool call authorization event received") .unwrap(); - if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event { let permission_kinds = tool_call_authorization .options .iter() @@ -943,13 +941,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut echo_completed = false; while let Some(event) = events.next().await { match event.unwrap() { - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { assert_eq!(tool_call.title, expected_tools.remove(0)); if tool_call.title == "Echo" { echo_id = Some(tool_call.id); } } - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( acp::ToolCallUpdate { id, fields: @@ -971,13 +969,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, _cx| thread.cancel()); + thread.update(cx, |thread, cx| thread.cancel(cx)); let events = events.collect::>().await; let last_event = events.last(); assert!( matches!( last_event, - Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) ), "unexpected event {last_event:?}" ); @@ -1159,7 +1157,7 @@ async fn test_truncate(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, _cx| thread.truncate(message_id)) + .update(cx, |thread, cx| thread.truncate(message_id, cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { @@ -1434,11 +1432,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events(result_events: Vec>) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { - AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + ThreadEvent::Stop(stop_reason) => Some(stop_reason), _ => None, }) .collect() @@ -1549,12 +1547,14 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, project_context.clone(), context_server_registry, action_log, templates, Some(model.clone()), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 429832010b..b3c62a3a64 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,25 +1,35 @@ -use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use crate::{ + ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates, +}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; +use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; +use chrono::{DateTime, Utc}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; use collections::IndexMap; use fs::Fs; use futures::{ + FutureExt, channel::{mpsc, oneshot}, + future::Shared, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, SharedString, Task}; +use git::repository::DiffType; +use gpui::{App, AppContext, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage, +}; +use project::{ + Project, + git_store::{GitStore, RepositoryState}, }; -use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; @@ -30,28 +40,7 @@ use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; /// The ID of the user prompt that initiated a request. /// @@ -71,7 +60,7 @@ impl std::fmt::Display for PromptId { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Message { User(UserMessage), Agent(AgentMessage), @@ -86,6 +75,18 @@ impl Message { } } + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], + } + } + pub fn to_markdown(&self) -> String { match self { Message::User(message) => message.to_markdown(), @@ -93,15 +94,22 @@ impl Message { Message::Resume => "[resumed after tool use limit was reached]".into(), } } + + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, + } + } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessage { pub id: UserMessageId, pub content: Vec, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum UserMessageContent { Text(String), Mention { uri: MentionUri, content: String }, @@ -313,9 +321,6 @@ impl AgentMessage { AgentMessageContent::RedactedThinking(_) => { markdown.push_str("\n") } - AgentMessageContent::Image(_) => { - markdown.push_str("\n"); - } AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -386,9 +391,6 @@ impl AgentMessage { AgentMessageContent::ToolUse(value) => { language_model::MessageContent::ToolUse(value.clone()) } - AgentMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } }; assistant_message.content.push(chunk); } @@ -418,13 +420,13 @@ impl AgentMessage { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentMessage { pub content: Vec, pub tool_results: IndexMap, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum AgentMessageContent { Text(String), Thinking { @@ -432,17 +434,18 @@ pub enum AgentMessageContent { signature: Option, }, RedactedThinking(String), - Image(LanguageModelImage), ToolUse(LanguageModelToolUse), } #[derive(Debug)] -pub enum AgentResponseEvent { - Text(String), - Thinking(String), +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + TitleUpdate(SharedString), Stop(acp::StopReason), } @@ -453,9 +456,28 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +enum ThreadTitle { + None, + Pending(Shared>), + Done(Result), +} + +impl ThreadTitle { + pub fn unwrap_or_default(&self) -> SharedString { + if let ThreadTitle::Done(Ok(title)) = self { + title.clone() + } else { + "New Thread".into() + } + } +} + pub struct Thread { - id: ThreadId, + id: acp::SessionId, prompt_id: PromptId, + updated_at: DateTime, + title: ThreadTitle, + summary: DetailedSummaryState, messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. @@ -465,45 +487,359 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, + request_token_usage: Vec, + cumulative_token_usage: TokenUsage, + initial_project_snapshot: Shared>>>, context_server_registry: Entity, profile_id: AgentProfileId, project_context: Rc>, templates: Arc, model: Option>, + summarization_model: Option>, project: Entity, action_log: Entity, } impl Thread { pub fn new( + id: acp::SessionId, project: Entity, project_context: Rc>, context_server_registry: Entity, action_log: Entity, templates: Arc, model: Option>, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { - id: ThreadId::new(), + id, prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: ThreadTitle::None, + summary: DetailedSummaryState::default(), messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: Vec::new(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: { + let project_snapshot = Self::project_snapshot(project.clone(), cx); + cx.foreground_executor() + .spawn(async move { Some(project_snapshot.await) }) + .shared() + }, context_server_registry, profile_id, project_context, templates, model, + summarization_model, project, action_log, } } + #[cfg(any(test, feature = "test-support"))] + pub fn test( + model: Arc, + project: Entity, + action_log: Entity, + cx: &mut Context, + ) -> Self { + use crate::generate_session_id; + + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + + Self::new( + generate_session_id(), + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + Some(model), + None, + cx, + ) + } + + pub fn id(&self) -> &acp::SessionId { + &self.id + } + + pub fn from_db( + id: acp::SessionId, + db_thread: DbThread, + project: Entity, + project_context: Rc>, + context_server_registry: Entity, + action_log: Entity, + templates: Arc, + model: Arc, + summarization_model: Option>, + cx: &mut Context, + ) -> Self { + let profile_id = db_thread + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); + Self { + id, + prompt_id: PromptId::new(), + title: ThreadTitle::Done(Ok(db_thread.title.clone())), + summary: db_thread.summary, + messages: db_thread.messages, + completion_mode: CompletionMode::Normal, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: db_thread.request_token_usage.clone(), + cumulative_token_usage: db_thread.cumulative_token_usage.clone(), + initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), + context_server_registry, + profile_id, + project_context, + templates, + model: Some(model), + summarization_model, + project, + action_log, + updated_at: db_thread.updated_at, + } + } + + pub fn to_db(&self, cx: &App) -> Task { + let initial_project_snapshot = self.initial_project_snapshot.clone(); + let mut thread = DbThread { + title: self.title.unwrap_or_default(), + messages: self.messages.clone(), + updated_at: self.updated_at.clone(), + summary: self.summary.clone(), + initial_project_snapshot: None, + cumulative_token_usage: self.cumulative_token_usage.clone(), + request_token_usage: self.request_token_usage.clone(), + model: self.model.as_ref().map(|model| DbLanguageModel { + provider: model.provider_id().to_string(), + model: model.name().0.to_string(), + }), + completion_mode: Some(self.completion_mode.into()), + profile: Some(self.profile_id.clone()), + }; + + cx.background_spawn(async move { + let initial_project_snapshot = initial_project_snapshot.await; + thread.initial_project_snapshot = initial_project_snapshot; + thread + }) + } + + pub fn replay( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(&user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } + } + rx + } + + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; + + let title = tool.initial_title(tool_use.input.clone()); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + let output = tool_result + .as_ref() + .and_then(|result| result.output.clone()); + if let Some(output) = output.clone() { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + raw_output: output, + ..Default::default() + }, + ); + } + + /// Create a snapshot of the current project state including git information and unsaved buffers. + fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_, cx| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + let mut unsaved_buffers = Vec::new(); + cx.update(|app_cx| { + let buffer_store = project.read(app_cx).buffer_store(); + for buffer_handle in buffer_store.read(app_cx).buffers() { + let buffer = buffer_handle.read(app_cx); + if buffer.is_dirty() { + if let Some(file) = buffer.file() { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); + } + } + } + }) + .ok(); + + Arc::new(ProjectSnapshot { + worktree_snapshots, + unsaved_buffer_paths: unsaved_buffers, + timestamp: Utc::now(), + }) + }) + } + + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().to_string(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + WorktreeSnapshot { + worktree_path, + git_state, + } + }) + } + pub fn project(&self) -> &Entity { &self.project } @@ -516,16 +852,27 @@ impl Thread { self.model.as_ref() } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = Some(model); + cx.notify() + } + + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } - pub fn set_completion_mode(&mut self, mode: CompletionMode) { + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { self.completion_mode = mode; + cx.notify() } #[cfg(any(test, feature = "test-support"))] @@ -553,29 +900,29 @@ impl Thread { self.profile_id = profile_id; } - pub fn cancel(&mut self) { + pub fn cancel(&mut self, cx: &mut Context) { if let Some(running_turn) = self.running_turn.take() { running_turn.cancel(); } - self.flush_pending_message(); + self.flush_pending_message(cx); } - pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { - self.cancel(); + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { return Err(anyhow!("Message not found")); }; self.messages.truncate(position); + cx.notify(); Ok(()) } pub fn resume( &mut self, cx: &mut Context, - ) -> Result>> { - anyhow::ensure!(self.model.is_some(), "Model not set"); + ) -> Result>> { anyhow::ensure!( self.tool_use_limit_reached, "can only resume after tool use limit is reached" @@ -596,7 +943,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> Result>> + ) -> Result>> where T: Into, { @@ -619,15 +966,12 @@ impl Thread { fn run_turn( &mut self, cx: &mut Context, - ) -> Result>> { - self.cancel(); + ) -> Result>> { + self.cancel(cx); - let model = self - .model() - .cloned() - .context("No language model configured")?; - let (events_tx, events_rx) = mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); + let model = self.model.clone().context("No language model configured")?; + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; self.running_turn = Some(RunningTurn { @@ -661,8 +1005,8 @@ impl Thread { LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.messages.truncate(message_ix); })?; return Ok(()); @@ -714,7 +1058,7 @@ impl Thread { log::info!("No tool uses found, completing turn"); return Ok(()); } else { - this.update(cx, |this, _| this.flush_pending_message())?; + this.update(cx, |this, cx| this.flush_pending_message(cx))?; completion_intent = CompletionIntent::ToolResults; } } @@ -728,8 +1072,8 @@ impl Thread { log::info!("Turn execution completed successfully"); } - this.update(cx, |this, _| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.running_turn.take(); }) .ok(); @@ -738,6 +1082,101 @@ impl Thread { Ok(events_rx) } + pub fn generate_title_if_needed( + &mut self, + cx: &mut Context, + ) -> Option>> { + if !matches!(self.title, ThreadTitle::None) { + return None; + } + + // todo!() copy logic from agent1 re: tool calls, etc.? + if self.messages.len() < 2 { + return None; + } + let Some(model) = self.summarization_model.clone() else { + return None; + }; + let (tx, rx) = mpsc::unbounded(); + + self.generate_title(model, ThreadEventStream(tx), cx); + Some(rx) + } + + fn generate_title( + &mut self, + model: Arc, + event_stream: ThreadEventStream, + cx: &mut Context, + ) { + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())], + cache: false, + }); + + let task = cx + .spawn(async move |this, cx| { + let result: anyhow::Result = async { + let mut messages = model.stream_completion(request, &cx).await?; + + let mut new_summary = String::new(); + while let Some(event) = messages.next().await { + let Ok(event) = event else { + continue; + }; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { .. }, + ) => { + // this.update(cx, |thread, cx| { + // thread.update_model_request_usage(amount as u32, limit, cx); + // })?; + // todo!()? not sure if this is the right place to do this. + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + new_summary.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + + anyhow::Ok(new_summary.into()) + } + .await; + + this.update(cx, |this, cx| { + if let Ok(title) = &result { + event_stream.send_title_update(title.clone()); + } + this.title = ThreadTitle::Done(result); + cx.notify(); + }) + .log_err(); + }) + .shared(); + + self.title = ThreadTitle::Pending(task.clone()); + cx.notify() + } + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { @@ -761,7 +1200,7 @@ impl Thread { fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { log::trace!("Handling streamed completion event: {:?}", event); @@ -769,7 +1208,7 @@ impl Thread { match event { StartMessage { .. } => { - self.flush_pending_message(); + self.flush_pending_message(cx); self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), @@ -803,7 +1242,7 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_text(&new_text); @@ -824,7 +1263,7 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_thinking(&new_text); @@ -856,7 +1295,7 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { cx.notify(); @@ -978,7 +1417,7 @@ impl Thread { self.pending_message.get_or_insert_default() } - fn flush_pending_message(&mut self) { + fn flush_pending_message(&mut self, cx: &mut Context) { let Some(mut message) = self.pending_message.take() else { return; }; @@ -995,9 +1434,7 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text( - "Tool canceled by user".into(), - ), + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), output: None, }, ); @@ -1005,6 +1442,7 @@ impl Thread { } self.messages.push(Message::Agent(message)); + cx.notify() } pub(crate) fn build_completion_request( @@ -1096,15 +1534,7 @@ impl Thread { ); let mut messages = vec![self.build_system_message()]; for message in &self.messages { - match message { - Message::User(message) => messages.push(message.to_request()), - Message::Agent(message) => messages.extend(message.to_request()), - Message::Resume => messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec!["Continue where you left off".into()], - cache: false, - }), - } + messages.extend(message.to_request()); } if let Some(message) = self.pending_message.as_ref() { @@ -1151,7 +1581,7 @@ struct RunningTurn { _task: Task<()>, /// The current event stream for the running turn. Used to report a final /// cancellation event if we cancel the turn. - event_stream: AgentResponseEventStream, + event_stream: ThreadEventStream, } impl RunningTurn { @@ -1204,6 +1634,17 @@ where cx: &mut App, ) -> Task>; + /// Emits events for a previous execution of the tool. + fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } @@ -1231,6 +1672,13 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; } impl AnyAgentTool for Erased> @@ -1282,21 +1730,45 @@ where }) }) } + + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) + } } #[derive(Clone)] -struct AgentResponseEventStream(mpsc::UnboundedSender>); +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); + } -impl AgentResponseEventStream { fn send_text(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) + .ok(); + } + + fn send_title_update(&self, text: SharedString) { + self.0 + .unbounded_send(Ok(ThreadEvent::TitleUpdate(text))) .ok(); } fn send_thinking(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) .ok(); } @@ -1308,7 +1780,7 @@ impl AgentResponseEventStream { input: serde_json::Value, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, title.to_string(), kind, @@ -1341,7 +1813,7 @@ impl AgentResponseEventStream { fields: acp::ToolCallUpdateFields, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.to_string().into()), fields, @@ -1355,17 +1827,17 @@ impl AgentResponseEventStream { match reason { StopReason::EndTurn => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn))) .ok(); } StopReason::MaxTokens => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens))) .ok(); } StopReason::Refusal => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal))) .ok(); } StopReason::ToolUse => {} @@ -1374,7 +1846,7 @@ impl AgentResponseEventStream { fn send_canceled(&self) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) .ok(); } @@ -1386,24 +1858,23 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = - ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, ) -> Self { Self { @@ -1421,7 +1892,7 @@ impl ToolCallEventStream { pub fn update_diff(&self, diff: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateDiff { id: acp::ToolCallId(self.tool_use_id.to_string().into()), diff, @@ -1434,7 +1905,7 @@ impl ToolCallEventStream { pub fn update_terminal(&self, terminal: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateTerminal { id: acp::ToolCallId(self.tool_use_id.to_string().into()), terminal, @@ -1452,7 +1923,7 @@ impl ToolCallEventStream { let (response_tx, response_rx) = oneshot::channel(); self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( ToolCallAuthorization { tool_call: acp::ToolCallUpdate { id: acp::ToolCallId(self.tool_use_id.to_string().into()), @@ -1502,13 +1973,13 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { auth } else { panic!("Expected ToolCallAuthorization but got: {:?}", event); @@ -1517,9 +1988,9 @@ impl ToolCallEventStreamReceiver { pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdate::UpdateTerminal(update), - ))) = event + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event { update.terminal } else { @@ -1530,7 +2001,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 @@ -1599,6 +2070,26 @@ impl From for UserMessageContent { } } +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + uri: None, + }), + UserMessageContent::Mention { .. } => { + todo!() + } + } + } +} + fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { LanguageModelImage { source: image_content.data.into(), diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs index db39e9278c..3b5225317f 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool { }) }) } + + fn replay( + &self, + _input: serde_json::Value, + _output: serde_json::Value, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index c55e503d76..756698bf3f 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; -use language::ToPoint; use language::language_settings::{self, FormatOnSave}; +use language::{LanguageRegistry, ToPoint}; use language_model::LanguageModelToolResultContent; use paths; use project::lsp_store::{FormatTrigger, LspFormatTarget}; @@ -98,11 +98,13 @@ pub enum EditFileMode { #[derive(Debug, Serialize, Deserialize)] pub struct EditFileToolOutput { + #[serde(alias = "original_path")] input_path: PathBuf, - project_path: PathBuf, new_text: String, old_text: Arc, + #[serde(default)] diff: String, + #[serde(alias = "raw_output")] edit_agent_output: EditAgentOutput, } @@ -122,12 +124,16 @@ impl From for LanguageModelToolResultContent { } pub struct EditFileTool { - thread: Entity, + thread: WeakEntity, + language_registry: Arc, } impl EditFileTool { - pub fn new(thread: Entity) -> Self { - Self { thread } + pub fn new(thread: WeakEntity, language_registry: Arc) -> Self { + Self { + thread, + language_registry, + } } fn authorize( @@ -167,8 +173,11 @@ impl EditFileTool { // Check if path is inside the global config directory // First check if it's already inside project - if not, try to canonicalize - let thread = self.thread.read(cx); - let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + let Ok(project_path) = self.thread.read_with(cx, |thread, cx| { + thread.project().read(cx).find_project_path(&input.path, cx) + }) else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; // If the path is inside the project, and it's not one of the above edge cases, // then no confirmation is necessary. Otherwise, confirmation is necessary. @@ -221,7 +230,12 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let project = self.thread.read(cx).project().clone(); + let Ok(project) = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), @@ -237,23 +251,17 @@ impl AgentTool for EditFileTool { }); } - let Some(request) = self.thread.update(cx, |thread, cx| { - thread - .build_completion_request(CompletionIntent::ToolResults, cx) - .ok() - }) else { - return Task::ready(Err(anyhow!("Failed to build completion request"))); - }; - let thread = self.thread.read(cx); - let Some(model) = thread.model().cloned() else { - return Task::ready(Err(anyhow!("No language model configured"))); - }; - let action_log = thread.action_log().clone(); - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { authorize.await?; + let (request, model, action_log) = self.thread.update(cx, |thread, cx| { + let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); + (request, thread.model().cloned(), thread.action_log().clone()) + })?; + let request = request?; + let model = model.context("No language model configured")?; + let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( model, @@ -419,7 +427,6 @@ impl AgentTool for EditFileTool { Ok(EditFileToolOutput { input_path: input.path, - project_path: project_path.path.to_path_buf(), new_text: new_text.clone(), old_text, diff: unified_diff, @@ -427,6 +434,25 @@ impl AgentTool for EditFileTool { }) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + output.input_path, + Some(output.old_text.to_string()), + output.new_text, + self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } } /// Validate that the file path is valid, meaning: @@ -497,7 +523,6 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -505,7 +530,6 @@ mod tests { use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; - use std::rc::Rc; use util::path; #[gpui::test] @@ -515,21 +539,10 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project, - Rc::default(), - context_server_registry, - action_log, - Templates::new(), - Some(model), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -537,7 +550,11 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -713,20 +730,8 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx)); // First, test with format_on_save enabled cx.update(|cx| { @@ -750,9 +755,10 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -806,7 +812,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -848,21 +858,10 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx)); // First, test with remove_trailing_whitespace_on_save enabled cx.update(|cx| { @@ -887,9 +886,10 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -938,10 +938,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -974,22 +975,12 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1111,22 +1102,12 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1220,23 +1201,12 @@ mod tests { cx, ) .await; - + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees let test_cases = vec![ @@ -1302,22 +1272,12 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases let test_cases = vec![ @@ -1386,22 +1346,12 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values let modes = vec![ @@ -1467,22 +1417,12 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - project.clone(), - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(EditFileTool { thread }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( tool.initial_title(Err(json!({ diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs index ecb855ac34..6984475d18 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -319,7 +319,7 @@ mod tests { use theme::ThemeSettings; use util::test::TempTree; - use crate::AgentResponseEvent; + use crate::ThreadEvent; use super::*; @@ -396,7 +396,7 @@ mod tests { }); cx.run_until_parked(); let event = stream_rx.try_next(); - if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { + if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event { auth.response.send(auth.options[0].id.clone()).unwrap(); } diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index c1c0970742..d71a128bfe 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool { } }; - let result_text = if response.results.len() == 1 { - "1 result".to_string() - } else { - format!("{} results", response.results.len()) - }; - event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(format!("Searched the web: {result_text}")), - content: Some( - response - .results - .iter() - .map(|result| acp::ToolCallContent::Content { - content: acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: result.title.clone(), - uri: result.url.clone(), - title: Some(result.title.clone()), - description: Some(result.text.clone()), - mime_type: None, - annotations: None, - size: None, - }), - }) - .collect(), - ), - ..Default::default() - }); + emit_update(&response, &event_stream); Ok(WebSearchToolOutput(response)) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + emit_update(&output.0, &event_stream); + Ok(()) + } +} + +fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) { + let result_text = if response.results.len() == 1 { + "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 00e3e3df50..7891316925 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -1,7 +1,7 @@ use std::{path::Path, rc::Rc}; use crate::AgentServerCommand; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentServerName}; use anyhow::Result; use gpui::AsyncApp; use thiserror::Error; @@ -14,12 +14,12 @@ mod v1; pub struct UnsupportedVersion; pub async fn connect( - server_name: &'static str, + server_name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, ) -> Result> { - let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; + let conn = v1::AcpConnection::stdio(server_name.clone(), command.clone(), &root_dir, cx).await; match conn { Ok(conn) => Ok(Rc::new(conn) as _), diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 74647f7313..72a0283ea2 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -10,7 +10,7 @@ use ui::App; use util::ResultExt as _; use crate::AgentServerCommand; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired}; #[derive(Clone)] struct OldAcpClientDelegate { @@ -354,7 +354,7 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } pub struct AcpConnection { - pub name: &'static str, + pub name: AgentServerName, pub connection: acp_old::AgentConnection, pub _child_status: Task>, pub current_thread: Rc>>, @@ -362,7 +362,7 @@ pub struct AcpConnection { impl AcpConnection { pub fn stdio( - name: &'static str, + name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, @@ -443,7 +443,7 @@ impl AgentConnection for AcpConnection { cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.name, self.clone(), project, session_id, cx) + AcpThread::new(self.name.0.clone(), self.clone(), project, session_id, cx) }); current_thread.replace(thread.downgrade()); thread diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index b77b5ef36d..af56249381 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -13,10 +13,10 @@ use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use crate::{AgentServerCommand, acp::UnsupportedVersion}; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired}; pub struct AcpConnection { - server_name: &'static str, + server_name: AgentServerName, connection: Rc, sessions: Rc>>, auth_methods: Vec, @@ -31,7 +31,7 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; impl AcpConnection { pub async fn stdio( - server_name: &'static str, + server_name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, @@ -150,7 +150,7 @@ impl AgentConnection for AcpConnection { let thread = cx.new(|cx| { AcpThread::new( - self.server_name, + self.server_name.0.clone(), self.clone(), project, session_id.clone(), diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index b3b8a33170..0836af1ba9 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -10,7 +10,7 @@ pub use claude::*; pub use gemini::*; pub use settings::*; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentServerName}; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -30,7 +30,7 @@ pub fn init(cx: &mut App) { pub trait AgentServer: Send { fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; + fn name(&self) -> AgentServerName; fn empty_state_headline(&self) -> &'static str; fn empty_state_message(&self) -> &'static str; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index d15cc1dd89..afc07b720f 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -30,18 +30,18 @@ use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName}; #[derive(Clone)] pub struct ClaudeCode; impl AgentServer for ClaudeCode { - fn name(&self) -> &'static str { - "Claude Code" + fn name(&self) -> AgentServerName { + AgentServerName("Claude Code".into()) } fn empty_state_headline(&self) -> &'static str { - self.name() + "Claude Code" } fn empty_state_message(&self) -> &'static str { diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index ad883f6da8..ab428fe5b3 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -2,7 +2,7 @@ use std::path::Path; use std::rc::Rc; use crate::{AgentServer, AgentServerCommand}; -use acp_thread::{AgentConnection, LoadError}; +use acp_thread::{AgentConnection, AgentServerName, LoadError}; use anyhow::Result; use gpui::{Entity, Task}; use project::Project; @@ -17,8 +17,8 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; impl AgentServer for Gemini { - fn name(&self) -> &'static str { - "Gemini" + fn name(&self) -> AgentServerName { + AgentServerName("Gemini".into()) } fn empty_state_headline(&self) -> &'static str { diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index 831d296eeb..efdeee9efd 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -3,8 +3,10 @@ mod entry_view_state; mod message_editor; mod model_selector; mod model_selector_popover; +mod thread_history; mod thread_view; pub use model_selector::AcpModelSelector; pub use model_selector_popover::AcpModelSelectorPopover; +pub use thread_history::{AcpThreadHistory, ThreadHistoryEvent}; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs new file mode 100644 index 0000000000..d0bf60ad72 --- /dev/null +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -0,0 +1,796 @@ +use crate::RemoveSelectedThread; +use agent_servers::AgentServer; +use agent2::{HistoryEntry, HistoryStore, NativeAgentServer}; +use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; +use editor::{Editor, EditorEvent}; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ScrollStrategy, Stateful, + Task, UniformListScrollHandle, Window, uniform_list, +}; +use project::Project; +use std::{fmt::Display, ops::Range, sync::Arc}; +use time::{OffsetDateTime, UtcOffset}; +use ui::{ + HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState, + Tooltip, prelude::*, +}; +use util::ResultExt; + +pub struct AcpThreadHistory { + pub(crate) history_store: Entity, + scroll_handle: UniformListScrollHandle, + selected_index: usize, + hovered_index: Option, + search_editor: Entity, + all_entries: Arc>, + // When the search is empty, we display date separators between history entries + // This vector contains an enum of either a separator or an actual entry + separated_items: Vec, + // Maps entry indexes to list item indexes + separated_item_indexes: Vec, + _separated_items_task: Option>, + search_state: SearchState, + scrollbar_visibility: bool, + scrollbar_state: ScrollbarState, + local_timezone: UtcOffset, + _subscriptions: Vec, +} + +enum SearchState { + Empty, + Searching { + query: SharedString, + _task: Task<()>, + }, + Searched { + query: SharedString, + matches: Vec, + }, +} + +enum ListItemType { + BucketSeparator(TimeBucket), + Entry { + index: usize, + format: EntryTimeFormat, + }, +} + +pub enum ThreadHistoryEvent { + Open(HistoryEntry), +} + +impl EventEmitter for AcpThreadHistory {} + +impl AcpThreadHistory { + pub(crate) fn new( + project: &Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let search_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("Search threads...", cx); + editor + }); + let history_store = HistoryStore::get_or_init(project, cx); + + let search_editor_subscription = + cx.subscribe(&search_editor, |this, search_editor, event, cx| { + if let EditorEvent::BufferEdited = event { + let query = search_editor.read(cx).text(cx); + this.search(query.into(), cx); + } + }); + + let history_store_subscription = cx.observe(&history_store, |this, _, cx| { + this.update_all_entries(cx); + }); + + let scroll_handle = UniformListScrollHandle::default(); + let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); + + let mut this = Self { + history_store, + scroll_handle, + selected_index: 0, + hovered_index: None, + search_state: SearchState::Empty, + all_entries: Default::default(), + separated_items: Default::default(), + separated_item_indexes: Default::default(), + search_editor, + scrollbar_visibility: true, + scrollbar_state, + local_timezone: UtcOffset::from_whole_seconds( + chrono::Local::now().offset().local_minus_utc(), + ) + .unwrap(), + _subscriptions: vec![search_editor_subscription, history_store_subscription], + _separated_items_task: None, + }; + this.update_all_entries(cx); + this + } + + fn update_all_entries(&mut self, cx: &mut Context) { + let new_entries: Arc> = self + .history_store + .update(cx, |store, cx| store.entries(cx)) + .into(); + + self._separated_items_task.take(); + + let mut items = Vec::with_capacity(new_entries.len() + 1); + let mut indexes = Vec::with_capacity(new_entries.len() + 1); + + let bg_task = cx.background_spawn(async move { + let mut bucket = None; + let today = Local::now().naive_local().date(); + + for (index, entry) in new_entries.iter().enumerate() { + let entry_date = entry + .updated_at() + .with_timezone(&Local) + .naive_local() + .date(); + let entry_bucket = TimeBucket::from_dates(today, entry_date); + + if Some(entry_bucket) != bucket { + bucket = Some(entry_bucket); + items.push(ListItemType::BucketSeparator(entry_bucket)); + } + + indexes.push(items.len() as u32); + items.push(ListItemType::Entry { + index, + format: entry_bucket.into(), + }); + } + (new_entries, items, indexes) + }); + + let task = cx.spawn(async move |this, cx| { + let (new_entries, items, indexes) = bg_task.await; + this.update(cx, |this, cx| { + let previously_selected_entry = + this.all_entries.get(this.selected_index).map(|e| e.id()); + + this.all_entries = new_entries; + this.separated_items = items; + this.separated_item_indexes = indexes; + + match &this.search_state { + SearchState::Empty => { + if this.selected_index >= this.all_entries.len() { + this.set_selected_entry_index( + this.all_entries.len().saturating_sub(1), + cx, + ); + } else if let Some(prev_id) = previously_selected_entry { + if let Some(new_ix) = this + .all_entries + .iter() + .position(|probe| probe.id() == prev_id) + { + this.set_selected_entry_index(new_ix, cx); + } + } + } + SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => { + this.search(query.clone(), cx); + } + } + + cx.notify(); + }) + .log_err(); + }); + self._separated_items_task = Some(task); + } + + fn search(&mut self, query: SharedString, cx: &mut Context) { + if query.is_empty() { + self.search_state = SearchState::Empty; + cx.notify(); + return; + } + + let all_entries = self.all_entries.clone(); + + let fuzzy_search_task = cx.background_spawn({ + let query = query.clone(); + let executor = cx.background_executor().clone(); + async move { + let mut candidates = Vec::with_capacity(all_entries.len()); + + for (idx, entry) in all_entries.iter().enumerate() { + match entry { + HistoryEntry::AcpThread(thread) => { + candidates.push(StringMatchCandidate::new(idx, &thread.title)); + } + HistoryEntry::TextThread(context) => { + candidates.push(StringMatchCandidate::new(idx, &context.title)); + } + } + } + + const MAX_MATCHES: usize = 100; + + fuzzy::match_strings( + &candidates, + &query, + false, + true, + MAX_MATCHES, + &Default::default(), + executor, + ) + .await + } + }); + + let task = cx.spawn({ + let query = query.clone(); + async move |this, cx| { + let matches = fuzzy_search_task.await; + + this.update(cx, |this, cx| { + let SearchState::Searching { + query: current_query, + _task, + } = &this.search_state + else { + return; + }; + + if &query == current_query { + this.search_state = SearchState::Searched { + query: query.clone(), + matches, + }; + + this.set_selected_entry_index(0, cx); + cx.notify(); + }; + }) + .log_err(); + } + }); + + self.search_state = SearchState::Searching { query, _task: task }; + cx.notify(); + } + + fn matched_count(&self) -> usize { + match &self.search_state { + SearchState::Empty => self.all_entries.len(), + SearchState::Searching { .. } => 0, + SearchState::Searched { matches, .. } => matches.len(), + } + } + + fn list_item_count(&self) -> usize { + match &self.search_state { + SearchState::Empty => self.separated_items.len(), + SearchState::Searching { .. } => 0, + SearchState::Searched { matches, .. } => matches.len(), + } + } + + fn search_produced_no_matches(&self) -> bool { + match &self.search_state { + SearchState::Empty => false, + SearchState::Searching { .. } => false, + SearchState::Searched { matches, .. } => matches.is_empty(), + } + } + + fn get_match(&self, ix: usize) -> Option<&HistoryEntry> { + match &self.search_state { + SearchState::Empty => self.all_entries.get(ix), + SearchState::Searching { .. } => None, + SearchState::Searched { matches, .. } => matches + .get(ix) + .and_then(|m| self.all_entries.get(m.candidate_id)), + } + } + + pub fn select_previous( + &mut self, + _: &menu::SelectPrevious, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + if self.selected_index == 0 { + self.set_selected_entry_index(count - 1, cx); + } else { + self.set_selected_entry_index(self.selected_index - 1, cx); + } + } + } + + pub fn select_next( + &mut self, + _: &menu::SelectNext, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + if self.selected_index == count - 1 { + self.set_selected_entry_index(0, cx); + } else { + self.set_selected_entry_index(self.selected_index + 1, cx); + } + } + } + + fn select_first( + &mut self, + _: &menu::SelectFirst, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + self.set_selected_entry_index(0, cx); + } + } + + fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { + let count = self.matched_count(); + if count > 0 { + self.set_selected_entry_index(count - 1, cx); + } + } + + fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context) { + self.selected_index = entry_index; + + let scroll_ix = match self.search_state { + SearchState::Empty | SearchState::Searching { .. } => self + .separated_item_indexes + .get(entry_index) + .map(|ix| *ix as usize) + .unwrap_or(entry_index + 1), + SearchState::Searched { .. } => entry_index, + }; + + self.scroll_handle + .scroll_to_item(scroll_ix, ScrollStrategy::Top); + + cx.notify(); + } + + fn render_scrollbar(&self, cx: &mut Context) -> Option> { + if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) { + return None; + } + + Some( + div() + .occlude() + .id("thread-history-scroll") + .h_full() + .bg(cx.theme().colors().panel_background.opacity(0.8)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .absolute() + .right_1() + .top_0() + .bottom_0() + .w_4() + .pl_1() + .cursor_default() + .on_mouse_move(cx.listener(|_, _, _window, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_scroll_wheel(cx.listener(|_, _, _window, cx| { + cx.notify(); + })) + .children(Scrollbar::vertical(self.scrollbar_state.clone())), + ) + } + + fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + self.confirm_entry(self.selected_index, cx); + } + + fn confirm_entry(&mut self, ix: usize, cx: &mut Context) { + let Some(entry) = self.get_match(ix) else { + return; + }; + cx.emit(ThreadHistoryEvent::Open(entry.clone())); + // let task_result = match entry { + // HistoryEntry::Thread(thread) => { + // self.agent_panel.update(cx, move |agent_panel, cx| todo!()) + // } + // HistoryEntry::Context(context) => { + // self.agent_panel.update(cx, move |agent_panel, cx| { + // agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) + // }) + // } + // }; + + // if let Some(task) = task_result.log_err() { + // task.detach_and_log_err(cx); + // }; + + cx.notify(); + } + + fn remove_selected_thread( + &mut self, + _: &RemoveSelectedThread, + _window: &mut Window, + cx: &mut Context, + ) { + self.remove_thread(self.selected_index, cx) + } + + fn remove_thread(&mut self, ix: usize, cx: &mut Context) { + let Some(entry) = self.get_match(ix) else { + return; + }; + todo!(); + // let task_result = match entry { + // HistoryEntry::Thread(thread) => todo!(), + // HistoryEntry::Context(context) => self + // .agent_panel + // .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), + // }; + + // if let Some(task) = task_result.log_err() { + // task.detach_and_log_err(cx); + // }; + + cx.notify(); + } + + fn list_items( + &mut self, + range: Range, + _window: &mut Window, + cx: &mut Context, + ) -> Vec { + match &self.search_state { + SearchState::Empty => self + .separated_items + .get(range) + .iter() + .flat_map(|items| { + items + .iter() + .map(|item| self.render_list_item(item, vec![], cx)) + }) + .collect(), + SearchState::Searched { matches, .. } => matches[range] + .iter() + .filter_map(|m| { + let entry = self.all_entries.get(m.candidate_id)?; + Some(self.render_history_entry( + entry, + EntryTimeFormat::DateAndTime, + m.candidate_id, + m.positions.clone(), + cx, + )) + }) + .collect(), + SearchState::Searching { .. } => { + vec![] + } + } + } + + fn render_list_item( + &self, + item: &ListItemType, + highlight_positions: Vec, + cx: &Context, + ) -> AnyElement { + match item { + ListItemType::Entry { index, format } => match self.all_entries.get(*index) { + Some(entry) => self + .render_history_entry(entry, *format, *index, highlight_positions, cx) + .into_any(), + None => Empty.into_any_element(), + }, + ListItemType::BucketSeparator(bucket) => div() + .px(DynamicSpacing::Base06.rems(cx)) + .pt_2() + .pb_1() + .child( + Label::new(bucket.to_string()) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + } + } + + fn render_history_entry( + &self, + entry: &HistoryEntry, + format: EntryTimeFormat, + list_entry_ix: usize, + highlight_positions: Vec, + cx: &Context, + ) -> AnyElement { + let selected = list_entry_ix == self.selected_index; + let hovered = Some(list_entry_ix) == self.hovered_index; + let timestamp = entry.updated_at().timestamp(); + let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone); + + h_flex() + .w_full() + .pb_1() + .child( + ListItem::new(list_entry_ix) + .rounded() + .toggle_state(selected) + .spacing(ListItemSpacing::Sparse) + .start_slot( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child( + HighlightedLabel::new(entry.title(), highlight_positions) + .size(LabelSize::Small) + .truncate(), + ) + .child( + Label::new(thread_timestamp) + .color(Color::Muted) + .size(LabelSize::XSmall), + ), + ) + .on_hover(cx.listener(move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_index = Some(list_entry_ix); + } else if this.hovered_index == Some(list_entry_ix) { + this.hovered_index = None; + } + + cx.notify(); + })) + .end_slot::(if hovered || selected { + Some( + IconButton::new("delete", IconName::Trash) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip(move |window, cx| { + Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) + }) + .on_click(cx.listener(move |this, _, _, cx| { + this.remove_thread(list_entry_ix, cx) + })), + ) + } else { + None + }) + .on_click( + cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)), + ), + ) + .into_any_element() + } +} + +impl Focusable for AcpThreadHistory { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.search_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadHistory { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .key_context("ThreadHistory") + .size_full() + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .on_action(cx.listener(Self::select_first)) + .on_action(cx.listener(Self::select_last)) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::remove_selected_thread)) + .when(!self.all_entries.is_empty(), |parent| { + parent.child( + h_flex() + .h(px(41.)) // Match the toolbar perfectly + .w_full() + .py_1() + .px_2() + .gap_2() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::MagnifyingGlass) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(self.search_editor.clone()), + ) + }) + .child({ + let view = v_flex() + .id("list-container") + .relative() + .overflow_hidden() + .flex_grow(); + + if self.all_entries.is_empty() { + view.justify_center() + .child( + h_flex().w_full().justify_center().child( + Label::new("You don't have any past threads yet.") + .size(LabelSize::Small), + ), + ) + } else if self.search_produced_no_matches() { + view.justify_center().child( + h_flex().w_full().justify_center().child( + Label::new("No threads match your search.").size(LabelSize::Small), + ), + ) + } else { + view.pr_5() + .child( + uniform_list( + "thread-history", + self.list_item_count(), + cx.processor(|this, range: Range, window, cx| { + this.list_items(range, window, cx) + }), + ) + .p_1() + .track_scroll(self.scroll_handle.clone()) + .flex_grow(), + ) + .when_some(self.render_scrollbar(cx), |div, scrollbar| { + div.child(scrollbar) + }) + } + }) + } +} + +#[derive(Clone, Copy)] +pub enum EntryTimeFormat { + DateAndTime, + TimeOnly, +} + +impl EntryTimeFormat { + fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String { + let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap(); + + match self { + EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp( + timestamp, + OffsetDateTime::now_utc(), + timezone, + time_format::TimestampFormat::EnhancedAbsolute, + ), + EntryTimeFormat::TimeOnly => time_format::format_time(timestamp), + } + } +} + +impl From for EntryTimeFormat { + fn from(bucket: TimeBucket) -> Self { + match bucket { + TimeBucket::Today => EntryTimeFormat::TimeOnly, + TimeBucket::Yesterday => EntryTimeFormat::TimeOnly, + TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime, + TimeBucket::PastWeek => EntryTimeFormat::DateAndTime, + TimeBucket::All => EntryTimeFormat::DateAndTime, + } + } +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +enum TimeBucket { + Today, + Yesterday, + ThisWeek, + PastWeek, + All, +} + +impl TimeBucket { + fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self { + if date == reference { + return TimeBucket::Today; + } + + if date == reference - TimeDelta::days(1) { + return TimeBucket::Yesterday; + } + + let week = date.iso_week(); + + if reference.iso_week() == week { + return TimeBucket::ThisWeek; + } + + let last_week = (reference - TimeDelta::days(7)).iso_week(); + + if week == last_week { + return TimeBucket::PastWeek; + } + + TimeBucket::All + } +} + +impl Display for TimeBucket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeBucket::Today => write!(f, "Today"), + TimeBucket::Yesterday => write!(f, "Yesterday"), + TimeBucket::ThisWeek => write!(f, "This Week"), + TimeBucket::PastWeek => write!(f, "Past Week"), + TimeBucket::All => write!(f, "All"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::NaiveDate; + + #[test] + fn test_time_bucket_from_dates() { + let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap(); + + let date = today; + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today); + + let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday); + + let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + // All: not in this week or last week + let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All); + + // Test year boundary cases + let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + + let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap(); + assert_eq!( + TimeBucket::from_dates(new_year, date), + TimeBucket::Yesterday + ); + + let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap(); + assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek); + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 2c02027c4d..676739da3b 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,6 +1,7 @@ use acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, + AcpThread, AcpThreadEvent, AcpThreadMetadata, AgentThreadEntry, AssistantMessage, + AssistantMessageChunk, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, + ToolCallStatus, UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -17,6 +18,7 @@ use editor::scroll::Autoscroll; use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; +use futures::StreamExt; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, @@ -122,6 +124,7 @@ pub struct AcpThreadView { editor_expanded: bool, terminal_expanded: bool, editing_message: Option, + history_store: Entity, _cancel_task: Option>, _subscriptions: [Subscription; 3], } @@ -133,6 +136,7 @@ enum ThreadState { Ready { thread: Entity, _subscription: [Subscription; 2], + _history_task: Option>, }, LoadError(LoadError), Unauthenticated { @@ -148,8 +152,10 @@ impl AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, + history_store: Entity, thread_store: Entity, text_thread_store: Entity, + restore_thread: Option, window: &mut Window, cx: &mut Context, ) -> Self { @@ -191,7 +197,16 @@ impl AcpThreadView { workspace: workspace.clone(), project: project.clone(), entry_view_state, - thread_state: Self::initial_state(agent, workspace, project, window, cx), + thread_state: Self::initial_state( + agent, + restore_thread, + history_store.clone(), + workspace, + project, + window, + cx, + ), + history_store, message_editor, model_selector: None, profile_selector: None, @@ -215,6 +230,8 @@ impl AcpThreadView { fn initial_state( agent: Rc, + restore_thread: Option, + history_store: Entity, workspace: WeakEntity, project: Entity, window: &mut Window, @@ -241,6 +258,25 @@ impl AcpThreadView { } }; + let mut history_task = None; + let history = connection.clone().history(); + if let Some(history) = history.clone() { + if let Some(mut history) = cx.update(|_, cx| history.observe_history(cx)).ok() { + history_task = Some(cx.spawn(async move |cx| { + while let Some(update) = history.next().await { + if !history_store + .update(cx, |history_store, cx| { + history_store.update_history(update, cx) + }) + .is_ok() + { + break; + } + } + })); + } + } + // this.update_in(cx, |_this, _window, cx| { // let status = connection.exit_status(cx); // cx.spawn(async move |this, cx| { @@ -254,19 +290,24 @@ impl AcpThreadView { // .detach(); // }) // .ok(); - - let Some(result) = cx - .update(|_, cx| { + let history = connection.clone().history(); + let task = cx.update(|_, cx| { + if let Some(restore_thread) = restore_thread + && let Some(history) = history + { + history.load_thread(project.clone(), &root_dir, restore_thread.id, cx) + } else { connection .clone() .new_thread(project.clone(), &root_dir, cx) - }) - .log_err() - else { + } + }); + + let Ok(task) = task else { return; }; - let result = match result.await { + let result = match task.await { Err(e) => { let mut cx = cx.clone(); if e.is::() { @@ -293,8 +334,13 @@ impl AcpThreadView { let action_log_subscription = cx.observe(&action_log, |_, _, cx| cx.notify()); - this.list_state - .splice(0..0, thread.read(cx).entries().len()); + let count = thread.read(cx).entries().len(); + this.list_state.splice(0..0, count); + this.entry_view_state.update(cx, |view_state, cx| { + for ix in 0..count { + view_state.sync_entry(ix, &thread, window, cx); + } + }); AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); @@ -319,6 +365,7 @@ impl AcpThreadView { this.thread_state = ThreadState::Ready { thread, _subscription: [thread_subscription, action_log_subscription], + _history_task: history_task, }; this.profile_selector = this.as_native_thread(cx).map(|thread| { @@ -698,6 +745,7 @@ impl AcpThreadView { AcpThreadEvent::ServerExited(status) => { self.thread_state = ThreadState::ServerExited { status: *status }; } + AcpThreadEvent::TitleUpdated => {} } cx.notify(); } @@ -726,6 +774,8 @@ impl AcpThreadView { } else { this.thread_state = Self::initial_state( agent, + None, // todo!() + this.history_store.clone(), this.workspace.clone(), project.clone(), window, @@ -2546,12 +2596,15 @@ impl AcpThreadView { return; }; - thread.update(cx, |thread, _cx| { + thread.update(cx, |thread, cx| { let current_mode = thread.completion_mode(); - thread.set_completion_mode(match current_mode { - CompletionMode::Burn => CompletionMode::Normal, - CompletionMode::Normal => CompletionMode::Burn, - }); + thread.set_completion_mode( + match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }, + cx, + ); }); } @@ -3265,8 +3318,8 @@ impl AcpThreadView { .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) .on_click({ cx.listener(move |this, _, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.set_completion_mode(CompletionMode::Burn); + thread.update(cx, |thread, cx| { + thread.set_completion_mode(CompletionMode::Burn, cx); }); this.resume_chat(cx); }) @@ -3587,7 +3640,7 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { #[cfg(test)] pub(crate) mod tests { - use acp_thread::StubAgentConnection; + use acp_thread::{AgentServerName, StubAgentConnection}; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; use editor::EditorSettings; @@ -3727,6 +3780,8 @@ pub(crate) mod tests { cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); let text_thread_store = cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx))); let thread_view = cx.update(|window, cx| { cx.new(|cx| { @@ -3734,8 +3789,10 @@ pub(crate) mod tests { Rc::new(agent), workspace.downgrade(), project, + history_store.clone(), thread_store.clone(), text_thread_store.clone(), + None, window, cx, ) @@ -3817,8 +3874,8 @@ pub(crate) mod tests { ui::IconName::Ai } - fn name(&self) -> &'static str { - "Test" + fn name(&self) -> AgentServerName { + AgentServerName("Test".into()) } fn empty_state_headline(&self) -> &'static str { @@ -3925,6 +3982,8 @@ pub(crate) mod tests { cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); let text_thread_store = cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx))); let connection = Rc::new(StubAgentConnection::new()); let thread_view = cx.update(|window, cx| { @@ -3933,8 +3992,10 @@ pub(crate) mod tests { Rc::new(StubAgentServer::new(connection.as_ref().clone())), workspace.downgrade(), project.clone(), + history_store, thread_store.clone(), text_thread_store.clone(), + None, window, cx, ) diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index b9e1ea5d0a..3d43c6883d 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -199,24 +199,21 @@ impl AgentDiffPane { let action_log = thread.action_log(cx).clone(); let mut this = Self { - _subscriptions: [ - Some( - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - ), + _subscriptions: vec![ + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), match &thread { - AgentDiffThread::Native(thread) => { - Some(cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - })) - } - AgentDiffThread::AcpThread(_) => None, + AgentDiffThread::Native(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_native_thread_event(event, cx) + }), + AgentDiffThread::AcpThread(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_acp_thread_event(event, cx) + }), }, - ] - .into_iter() - .flatten() - .collect(), + ], title: SharedString::default(), multibuffer, editor, @@ -324,13 +321,20 @@ impl AgentDiffPane { } } - fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { + fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { match event { ThreadEvent::SummaryGenerated => self.update_title(cx), _ => {} } } + fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context) { + match event { + AcpThreadEvent::TitleUpdated => self.update_title(cx), + _ => {} + } + } + pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { self.editor.update(cx, |editor, cx| { @@ -1521,7 +1525,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::EntriesRemoved(_) + AcpThreadEvent::TitleUpdated + | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 4cb231f357..8392c5589b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -4,11 +4,13 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Duration; +use acp_thread::AcpThreadMetadata; use agent_servers::AgentServer; +use agent2::HistoryEntry; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; -use crate::NewExternalAgentThread; +use crate::acp::{AcpThreadHistory, ThreadHistoryEvent}; use crate::agent_diff::AgentDiffThread; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, @@ -29,6 +31,7 @@ use crate::{ thread_history::{HistoryEntryElement, ThreadHistory}, ui::{AgentOnboardingModal, EndTrialUpsell}, }; +use crate::{ExternalAgent, NewExternalAgentThread}; use agent::{ Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, context_store::ContextStore, @@ -119,7 +122,7 @@ pub fn init(cx: &mut App) { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| { - panel.new_external_thread(action.agent, window, cx) + panel.new_external_thread(action.agent, None, window, cx) }); } }) @@ -478,6 +481,7 @@ pub struct AgentPanel { previous_view: Option, history_store: Entity, history: Entity, + acp_history: Entity, hovered_recent_history_item: Option, new_thread_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, @@ -744,6 +748,27 @@ impl AgentPanel { ) }); + let acp_history = cx.new(|cx| AcpThreadHistory::new(&project, window, cx)); + cx.subscribe_in( + &acp_history, + window, + |this, _, event, window, cx| match event { + ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => { + let agent_choice = match thread.agent.0.as_ref() { + "Claude Code" => Some(ExternalAgent::ClaudeCode), + "Gemini" => Some(ExternalAgent::Gemini), + "Native Agent" => Some(ExternalAgent::NativeAgent), + _ => None, + }; + this.new_external_thread(agent_choice, Some(thread.clone()), window, cx); + } + ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => { + todo!() + } + }, + ) + .detach(); + Self { active_view, workspace, @@ -765,6 +790,7 @@ impl AgentPanel { previous_view: None, history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), + acp_history, hovered_recent_history_item: None, new_thread_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), @@ -954,6 +980,7 @@ impl AgentPanel { fn new_external_thread( &mut self, agent_choice: Option, + restore_thread: Option, window: &mut Window, cx: &mut Context, ) { @@ -1004,13 +1031,16 @@ impl AgentPanel { }; this.update_in(cx, |this, window, cx| { + let acp_history_store = this.acp_history.read(cx).history_store.clone(); let thread_view = cx.new(|cx| { crate::acp::AcpThreadView::new( server, workspace.clone(), project, + acp_history_store, thread_store.clone(), text_thread_store.clone(), + restore_thread, window, cx, ) @@ -1669,13 +1699,13 @@ impl AgentPanel { window.dispatch_action(NewTextThread.boxed_clone(), cx); } AgentType::NativeAgent => { - self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx) } AgentType::Gemini => { - self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx) } AgentType::ClaudeCode => { - self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx) } } } @@ -1686,7 +1716,14 @@ impl Focusable for AgentPanel { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), - ActiveView::History => self.history.focus_handle(cx), + ActiveView::History => { + if cx.has_flag::() { + self.acp_history.focus_handle(cx) + } else { + self.history.focus_handle(cx) + } + } + ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { if let Some(configuration) = self.configuration.as_ref() { @@ -3517,7 +3554,13 @@ impl Render for AgentPanel { ActiveView::ExternalAgentThread { thread_view, .. } => parent .child(thread_view.clone()) .child(self.render_drag_target(cx)), - ActiveView::History => parent.child(self.history.clone()), + ActiveView::History => { + if cx.has_flag::() { + parent.child(self.acp_history.clone()) + } else { + parent.child(self.history.clone()) + } + } ActiveView::TextThread { context_editor, buffer_search_bar,