From 6b6b7e66e1466109917ef2d4b5c0ca5b3c219c1f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 15 Aug 2025 17:00:26 +0200 Subject: [PATCH 01/25] Start on a new db module --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 3 +- crates/acp_thread/src/mention.rs | 3 +- crates/agent2/Cargo.toml | 1 + crates/agent2/src/agent2.rs | 2 + crates/agent2/src/db.rs | 167 ++++++++++++++++++++++++++++ crates/agent2/src/thread.rs | 10 +- 7 files changed, 180 insertions(+), 7 deletions(-) create mode 100644 crates/agent2/src/db.rs diff --git a/Cargo.lock b/Cargo.lock index 2353733dc0..59299885d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,6 +191,7 @@ version = "0.1.0" dependencies = [ "acp_thread", "action_log", + "agent", "agent-client-protocol", "agent_servers", "agent_settings", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index b2116020fb..7a67e885db 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -4,11 +4,12 @@ use anyhow::Result; use collections::IndexMap; use gpui::{Entity, SharedString, Task}; use project::Project; +use serde::{Deserialize, Serialize}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; use uuid::Uuid; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct UserMessageId(Arc); impl UserMessageId { diff --git a/crates/acp_thread/src/mention.rs b/crates/acp_thread/src/mention.rs index b9b021c4ca..6686e9ebf1 100644 --- a/crates/acp_thread/src/mention.rs +++ b/crates/acp_thread/src/mention.rs @@ -2,6 +2,7 @@ use agent::ThreadId; use anyhow::{Context as _, Result, bail}; use file_icons::FileIcons; use prompt_store::{PromptId, UserPromptId}; +use serde::{Deserialize, Serialize}; use std::{ fmt, ops::Range, @@ -11,7 +12,7 @@ use std::{ use ui::{App, IconName, SharedString}; use url::Url; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum MentionUri { File { abs_path: PathBuf, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ac1840e5e5..60a327c73b 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -17,6 +17,7 @@ action_log.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true agent_settings.workspace = true +agent.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index f13cd1bd67..31e18e3769 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,4 +1,5 @@ mod agent; +mod db; mod native_agent_server; mod templates; mod thread; @@ -8,6 +9,7 @@ mod tools; mod tests; pub use agent::*; +pub use db::*; pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs new file mode 100644 index 0000000000..e17a1121b3 --- /dev/null +++ b/crates/agent2/src/db.rs @@ -0,0 +1,167 @@ +use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use agent::thread_store; +use agent_settings::{AgentProfileId, CompletionMode}; +use anyhow::Result; +use chrono::{DateTime, Utc}; +use collections::{HashMap, IndexMap}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use ui::SharedString; + +pub type DbMessage = crate::Message; +pub type DbSummary = agent::thread::DetailedSummaryState; +pub type DbLanguageModel = thread_store::SerializedLanguageModel; +pub type DbThreadMetadata = thread_store::SerializedThreadMetadata; + +#[derive(Debug, Serialize, Deserialize)] +pub struct DbThread { + pub title: SharedString, + pub messages: Vec, + pub updated_at: DateTime, + #[serde(default)] + pub summary: DbSummary, + #[serde(default)] + pub initial_project_snapshot: Option>, + #[serde(default)] + pub cumulative_token_usage: language_model::TokenUsage, + #[serde(default)] + pub request_token_usage: Vec, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub completion_mode: Option, + #[serde(default)] + pub profile: Option, +} + +impl DbThread { + pub const VERSION: &'static str = "0.3.0"; + + pub fn from_json(json: &[u8]) -> Result { + let saved_thread_json = serde_json::from_slice::(json)?; + match saved_thread_json.get("version") { + Some(serde_json::Value::String(version)) => match version.as_str() { + Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?), + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + }, + _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + } + } + + fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result { + let mut messages = Vec::new(); + for msg in thread.messages { + let message = match msg.role { + language_model::Role::User => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { text, .. } => { + // User messages don't have thinking segments, but handle gracefully + content.push(UserMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::RedactedThinking { .. } => { + // User messages don't have redacted thinking, skip. + } + } + } + + // If no content was added, add context as text if available + if content.is_empty() && !msg.context.is_empty() { + content.push(UserMessageContent::Text(msg.context)); + } + + crate::Message::User(UserMessage { + // MessageId from old format can't be meaningfully converted, so generate a new one + id: acp_thread::UserMessageId::new(), + content, + }) + } + language_model::Role::Assistant => { + let mut content = Vec::new(); + + // Convert segments to content + for segment in msg.segments { + match segment { + thread_store::SerializedMessageSegment::Text { text } => { + content.push(AgentMessageContent::Text(text)); + } + thread_store::SerializedMessageSegment::Thinking { + text, + signature, + } => { + content.push(AgentMessageContent::Thinking { text, signature }); + } + thread_store::SerializedMessageSegment::RedactedThinking { data } => { + content.push(AgentMessageContent::RedactedThinking(data)); + } + } + } + + // Convert tool uses + let mut tool_names_by_id = HashMap::default(); + for tool_use in msg.tool_uses { + tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone()); + content.push(AgentMessageContent::ToolUse( + language_model::LanguageModelToolUse { + id: tool_use.id, + name: tool_use.name.into(), + raw_input: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + input: tool_use.input, + is_input_complete: true, + }, + )); + } + + // Convert tool results + let mut tool_results = IndexMap::default(); + for tool_result in msg.tool_results { + let name = tool_names_by_id + .remove(&tool_result.tool_use_id) + .unwrap_or_else(|| SharedString::from("unknown")); + tool_results.insert( + tool_result.tool_use_id.clone(), + language_model::LanguageModelToolResult { + tool_use_id: tool_result.tool_use_id, + tool_name: name.into(), + is_error: tool_result.is_error, + content: tool_result.content, + output: tool_result.output, + }, + ); + } + + crate::Message::Agent(AgentMessage { + content, + tool_results, + }) + } + language_model::Role::System => { + // Skip system messages as they're not supported in the new format + continue; + } + }; + + messages.push(message); + } + + Ok(Self { + title: thread.summary, + messages, + updated_at: thread.updated_at, + summary: thread.detailed_summary_state, + initial_project_snapshot: thread.initial_project_snapshot, + cumulative_token_usage: thread.cumulative_token_usage, + request_token_usage: thread.request_token_usage, + model: thread.model, + completion_mode: thread.completion_mode, + profile: thread.profile, + }) + } +} diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 231ee92dda..784477d677 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -29,7 +29,7 @@ use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Message { User(UserMessage), Agent(AgentMessage), @@ -53,13 +53,13 @@ impl Message { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessage { pub id: UserMessageId, pub content: Vec, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum UserMessageContent { Text(String), Mention { uri: MentionUri, content: String }, @@ -376,13 +376,13 @@ impl AgentMessage { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentMessage { pub content: Vec, pub tool_results: IndexMap, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum AgentMessageContent { Text(String), Thinking { From fd8ea2acfc52376cdbe00fd517c2a53d6cea9048 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 13:11:36 -0600 Subject: [PATCH 02/25] Test round-tripping old threads --- Cargo.lock | 3 + crates/agent/src/thread_store.rs | 2 +- crates/agent2/Cargo.toml | 4 + crates/agent2/src/db.rs | 314 ++++++++++++++++++++++++++++++- 4 files changed, 318 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 59299885d7..8e6f4b656a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -222,6 +222,7 @@ dependencies = [ "log", "lsp", "open", + "parking_lot", "paths", "portable-pty", "pretty_assertions", @@ -234,6 +235,7 @@ dependencies = [ "serde_json", "settings", "smol", + "sqlez", "task", "tempfile", "terminal", @@ -250,6 +252,7 @@ dependencies = [ "workspace-hack", "worktree", "zlog", + "zstd", ] [[package]] diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 12c94a522d..e24a5ec782 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -893,7 +893,7 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = if *ZED_STATELESS { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { Connection::open_memory(Some("THREAD_FALLBACK_DB")) } else { Connection::open_file(&sqlite_path.to_string_lossy()) diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 60a327c73b..de1b1f74ec 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -38,6 +38,7 @@ language_model.workspace = true language_models.workspace = true log.workspace = true open.workspace = true +parking_lot.workspace = true paths.workspace = true portable-pty.workspace = true project.workspace = true @@ -47,6 +48,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sqlez.workspace = true smol.workspace = true task.workspace = true terminal.workspace = true @@ -58,8 +60,10 @@ watch.workspace = true web_search.workspace = true which.workspace = true workspace-hack.workspace = true +zstd.workspace = true [dev-dependencies] +agent = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index e17a1121b3..e8f85b0346 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,17 +1,34 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use agent::thread_store; +use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; -use anyhow::Result; +use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, IndexMap}; +use futures::{FutureExt, future::Shared}; +use gpui::{BackgroundExecutor, Global, ReadGlobal, Task}; +use indoc::indoc; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use ui::SharedString; +use sqlez::{ + bindable::{Bind, Column}, + connection::Connection, + statement::Statement, +}; +use std::{path::PathBuf, sync::Arc}; +use ui::{App, SharedString}; pub type DbMessage = crate::Message; pub type DbSummary = agent::thread::DetailedSummaryState; pub type DbLanguageModel = thread_store::SerializedLanguageModel; -pub type DbThreadMetadata = thread_store::SerializedThreadMetadata; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbThreadMetadata { + pub id: acp::SessionId, + #[serde(alias = "summary")] + pub title: SharedString, + pub updated_at: DateTime, +} #[derive(Debug, Serialize, Deserialize)] pub struct DbThread { @@ -165,3 +182,292 @@ impl DbThread { }) } } + +pub static ZED_STATELESS: std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DataType { + #[serde(rename = "json")] + Json, + #[serde(rename = "zstd")] + Zstd, +} + +impl Bind for DataType { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let value = match self { + DataType::Json => "json", + DataType::Zstd => "zstd", + }; + value.bind(statement, start_index) + } +} + +impl Column for DataType { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (value, next_index) = String::column(statement, start_index)?; + let data_type = match value.as_str() { + "json" => DataType::Json, + "zstd" => DataType::Zstd, + _ => anyhow::bail!("Unknown data type: {}", value), + }; + Ok((data_type, next_index)) + } +} + +struct GlobalThreadsDatabase(Shared, Arc>>>); + +impl Global for GlobalThreadsDatabase {} + +pub(crate) struct ThreadsDatabase { + executor: BackgroundExecutor, + connection: Arc>, +} + +impl ThreadsDatabase { + fn connection(&self) -> Arc> { + self.connection.clone() + } + + const COMPRESSION_LEVEL: i32 = 3; +} + +impl ThreadsDatabase { + fn global_future( + cx: &mut App, + ) -> Shared, Arc>>> { + GlobalThreadsDatabase::global(cx).0.clone() + } + + fn init(cx: &mut App) { + let executor = cx.background_executor().clone(); + let database_future = executor + .spawn({ + let executor = executor.clone(); + let threads_dir = paths::data_dir().join("threads"); + async move { + match ThreadsDatabase::new(threads_dir, executor) { + Ok(db) => Ok(Arc::new(db)), + Err(err) => Err(Arc::new(err)), + } + } + }) + .shared(); + + cx.set_global(GlobalThreadsDatabase(database_future)); + } + + pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + std::fs::create_dir_all(&threads_dir)?; + let sqlite_path = threads_dir.join("threads.db"); + Connection::open_file(&sqlite_path.to_string_lossy()) + }; + + connection.exec(indoc! {" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor: executor.clone(), + connection: Arc::new(Mutex::new(connection)), + }; + + Ok(db) + } + + fn save_thread_sync( + connection: &Arc>, + id: acp::SessionId, + thread: DbThread, + ) -> Result<()> { + let json_data = serde_json::to_string(&thread)?; + let title = thread.title.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + + let connection = connection.lock(); + + let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + "})?; + + insert((id.0, title, updated_at, data_type, data))?; + + Ok(()) + } + + pub fn list_threads(&self) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = + connection.select_bound::<(), (Arc, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; + let mut threads = Vec::new(); + + for (id, summary, updated_at) in rows { + threads.push(DbThreadMetadata { + id: acp::SessionId(id), + title: summary.into(), + updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), + }); + } + + Ok(threads) + }) + } + + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = connection.select_bound::, (DataType, Vec)>(indoc! {" + SELECT data_type, data FROM threads WHERE id = ? LIMIT 1 + "})?; + + let rows = select(id.0)?; + if let Some((data_type, data)) = rows.into_iter().next() { + let json_data = match data_type { + DataType::Zstd => { + let decompressed = zstd::decode_all(&data[..])?; + String::from_utf8(decompressed)? + } + DataType::Json => String::from_utf8(data)?, + }; + + let thread = DbThread::from_json(json_data.as_bytes())?; + Ok(Some(thread)) + } else { + Ok(None) + } + }) + } + + pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> { + let connection = self.connection.clone(); + + self.executor + .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) + } + + pub fn delete_thread(&self, id: acp::SessionId) -> Task> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + + let mut delete = connection.exec_bound::>(indoc! {" + DELETE FROM threads WHERE id = ? + "})?; + + delete(id.0)?; + + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent::MessageSegment; + use agent::context::LoadedContext; + use client::Client; + use fs::FakeFs; + use gpui::AppContext; + use gpui::TestAppContext; + use http_client::FakeHttpClient; + use language_model::Role; + use pretty_assertions::assert_matches; + use project::Project; + use settings::SettingsStore; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + ThreadsDatabase::init(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + agent::init(cx); + agent_settings::init(cx); + language_model::init(client.clone(), cx); + }); + } + + #[gpui::test] + async fn test_retrieving_old_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + // Save a thread using the old agent. + { + let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx)); + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + thread.update(cx, |thread, cx| { + thread.insert_message( + Role::User, + vec![MessageSegment::Text("Hey!".into())], + LoadedContext::default(), + vec![], + false, + cx, + ); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text("How're you doing?".into())], + LoadedContext::default(), + vec![], + false, + cx, + ) + }); + thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) + .await + .unwrap(); + } + + let db = cx + .update(|cx| ThreadsDatabase::global_future(cx)) + .await + .unwrap(); + let threads = db.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + let thread = db + .load_thread(threads[0].id.clone()) + .await + .unwrap() + .unwrap(); + assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n"); + assert_eq!( + thread.messages[1].to_markdown(), + "## Assistant\n\nHow're you doing?\n" + ); + } +} From 251baacdabf6bd13452daa947078c03856ef148c Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 14:46:52 -0600 Subject: [PATCH 03/25] WIP --- Cargo.lock | 1 + crates/acp_thread/Cargo.toml | 1 + crates/acp_thread/src/acp_thread.rs | 17 ++++++++++++-- crates/acp_thread/src/connection.rs | 8 ++++++- crates/agent2/src/agent.rs | 23 ++++++++++++++++++- crates/agent2/src/db.rs | 31 ++++++-------------------- crates/agent_servers/src/acp/v0.rs | 6 ++++- crates/agent_servers/src/acp/v1.rs | 6 ++++- crates/agent_servers/src/claude.rs | 6 ++++- crates/agent_ui/src/acp/thread_view.rs | 6 ++++- crates/agent_ui/src/agent_ui.rs | 1 + crates/agent_ui/src/thread_history.rs | 2 +- 12 files changed, 75 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8e6f4b656a..a20579cce3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,7 @@ dependencies = [ "agent-client-protocol", "anyhow", "buffer_diff", + "chrono", "collections", "editor", "env_logger 0.11.8", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 2b9a6513c8..cbe74c1f37 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -21,6 +21,7 @@ agent-client-protocol.workspace = true agent.workspace = true anyhow.workspace = true buffer_diff.workspace = true +chrono.workspace = true collections.workspace = true editor.workspace = true file_icons.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4995ddb9df..d8b0d5805a 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -6,11 +6,13 @@ mod terminal; pub use connection::*; pub use diff::*; pub use mention::*; +use serde::{Deserialize, Serialize}; pub use terminal::*; use action_log::ActionLog; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; +use chrono::{DateTime, Utc}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; @@ -632,6 +634,13 @@ impl PlanEntry { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcpThreadMetadata { + pub id: acp::SessionId, + pub title: SharedString, + pub updated_at: DateTime, +} + pub struct AcpThread { title: SharedString, entries: Vec, @@ -1608,7 +1617,7 @@ mod tests { use super::*; use anyhow::anyhow; use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext, WeakEntity}; + use gpui::{App, AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::{FakeFs, Fs}; use rand::Rng as _; @@ -2284,7 +2293,7 @@ mod tests { self: Rc, project: Entity, _cwd: &Path, - cx: &mut gpui::App, + cx: &mut App, ) -> Task>> { let session_id = acp::SessionId( rand::thread_rng() @@ -2300,6 +2309,10 @@ mod tests { Task::ready(Ok(thread)) } + fn list_threads(&self, _: &mut App) -> Task>> { + unimplemented!() + } + fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task> { if self.auth_methods().iter().any(|m| m.id == method) { Task::ready(Ok(())) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 7a67e885db..bc31ad0fe7 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,4 +1,4 @@ -use crate::AcpThread; +use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; @@ -26,6 +26,8 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; + fn list_threads(&self, _cx: &mut App) -> Task>>; + fn auth_methods(&self) -> &[acp::AuthMethod]; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; @@ -264,6 +266,10 @@ mod test_support { unimplemented!() } + fn list_threads(&self, _: &mut App) -> Task>> { + unimplemented!() + } + fn prompt( &self, _id: Option, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 358365d11f..2bcc24d00c 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,16 +1,18 @@ +use crate::ThreadsDatabase; use crate::{ AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use acp_thread::AgentModelSelector; +use acp_thread::{AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; use futures::channel::mpsc; +use futures::future::Shared; use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, @@ -166,6 +168,7 @@ pub struct NativeAgent { models: LanguageModels, project: Entity, prompt_store: Option>, + thread_database: Shared, Arc>>>, fs: Arc, _subscriptions: Vec, } @@ -208,6 +211,7 @@ impl NativeAgent { context_server_registry: cx.new(|cx| { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), + thread_database: ThreadsDatabase::connect(cx), templates, models: LanguageModels::new(cx), project, @@ -751,6 +755,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } + fn list_threads(&self, cx: &mut App) -> Task>> { + let database = self.0.read(cx).thread_database.clone(); + cx.background_executor().spawn(async move { + let database = database.await.map_err(|e| anyhow!(e))?; + let results = database.list_threads().await?; + + Ok(results + .into_iter() + .map(|thread| AcpThreadMetadata { + id: thread.id, + title: thread.title, + updated_at: thread.updated_at, + }) + .collect()) + }) + } + fn model_selector(&self) -> Option> { Some(Rc::new(self.clone()) as Rc) } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index e8f85b0346..d40352f257 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -216,10 +216,6 @@ impl Column for DataType { } } -struct GlobalThreadsDatabase(Shared, Arc>>>); - -impl Global for GlobalThreadsDatabase {} - pub(crate) struct ThreadsDatabase { executor: BackgroundExecutor, connection: Arc>, @@ -234,34 +230,26 @@ impl ThreadsDatabase { } impl ThreadsDatabase { - fn global_future( - cx: &mut App, - ) -> Shared, Arc>>> { - GlobalThreadsDatabase::global(cx).0.clone() - } - - fn init(cx: &mut App) { + pub fn connect(cx: &mut App) -> Shared, Arc>>> { let executor = cx.background_executor().clone(); - let database_future = executor + executor .spawn({ let executor = executor.clone(); - let threads_dir = paths::data_dir().join("threads"); async move { - match ThreadsDatabase::new(threads_dir, executor) { + match ThreadsDatabase::new(executor) { Ok(db) => Ok(Arc::new(db)), Err(err) => Err(Arc::new(err)), } } }) - .shared(); - - cx.set_global(GlobalThreadsDatabase(database_future)); + .shared() } - pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result { + pub fn new(executor: BackgroundExecutor) -> Result { let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { Connection::open_memory(Some("THREAD_FALLBACK_DB")) } else { + let threads_dir = paths::data_dir().join("threads"); std::fs::create_dir_all(&threads_dir)?; let sqlite_path = threads_dir.join("threads.db"); Connection::open_file(&sqlite_path.to_string_lossy()) @@ -397,7 +385,6 @@ mod tests { use gpui::TestAppContext; use http_client::FakeHttpClient; use language_model::Role; - use pretty_assertions::assert_matches; use project::Project; use settings::SettingsStore; @@ -408,7 +395,6 @@ mod tests { cx.set_global(settings_store); Project::init_settings(cx); language::init(cx); - ThreadsDatabase::init(cx); let http_client = FakeHttpClient::with_404_response(); let clock = Arc::new(clock::FakeSystemClock::new()); @@ -453,10 +439,7 @@ mod tests { .unwrap(); } - let db = cx - .update(|cx| ThreadsDatabase::global_future(cx)) - .await - .unwrap(); + let db = cx.update(|cx| ThreadsDatabase::connect(cx)).await.unwrap(); let threads = db.list_threads().await.unwrap(); assert_eq!(threads.len(), 1); let thread = db diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index e936c87643..d0fda4e020 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -10,7 +10,7 @@ use ui::App; use util::ResultExt as _; use crate::AgentServerCommand; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection, AuthRequired}; #[derive(Clone)] struct OldAcpClientDelegate { @@ -451,6 +451,10 @@ impl AgentConnection for AcpConnection { }) } + fn list_threads(&self, _cx: &mut App) -> Task>> { + Task::ready(Ok(Vec::default())) + } + fn auth_methods(&self) -> &[acp::AuthMethod] { &[] } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index 36511e4644..b412cc34c0 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -11,7 +11,7 @@ use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use crate::{AgentServerCommand, acp::UnsupportedVersion}; -use acp_thread::{AcpThread, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection, AuthRequired}; pub struct AcpConnection { server_name: &'static str, @@ -169,6 +169,10 @@ impl AgentConnection for AcpConnection { }) } + fn list_threads(&self, _cx: &mut App) -> Task>> { + Task::ready(Ok(Vec::default())) + } + fn prompt( &self, _id: Option, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index e1cc709289..291d27fd6a 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -30,7 +30,7 @@ use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection}; #[derive(Clone)] pub struct ClaudeCode; @@ -209,6 +209,10 @@ impl AgentConnection for ClaudeAgentConnection { Task::ready(Err(anyhow!("Authentication not supported"))) } + fn list_threads(&self, _cx: &mut App) -> Task>> { + Task::ready(Ok(Vec::default())) + } + fn prompt( &self, _id: Option, diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 87af75f046..a459c36f81 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -3583,7 +3583,7 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { #[cfg(test)] pub(crate) mod tests { - use acp_thread::StubAgentConnection; + use acp_thread::{AcpThreadMetadata, StubAgentConnection}; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; use editor::EditorSettings; @@ -3819,6 +3819,10 @@ pub(crate) mod tests { unimplemented!() } + fn list_threads(&self, _cx: &mut App) -> Task>> { + Task::ready(Ok(vec![])) + } + fn prompt( &self, _id: Option, diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 4f5f022593..a57e5b0563 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -9,6 +9,7 @@ mod context_picker; mod context_server_configuration; mod context_strip; mod debug; +mod history_store; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index b8d1db88d6..04ccec0975 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -1,5 +1,5 @@ +use crate::history_store::{HistoryEntry, HistoryStore}; use crate::{AgentPanel, RemoveSelectedThread}; -use agent::history_store::{HistoryEntry, HistoryStore}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; From 296e3fcf69b51d6557f644e977c4f221366aeb79 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 15:25:56 -0600 Subject: [PATCH 04/25] TEMP --- crates/agent_ui/src/agent_ui.rs | 2 +- crates/agent_ui/src/thread_history.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index a57e5b0563..946d24d9ae 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -9,7 +9,6 @@ mod context_picker; mod context_server_configuration; mod context_strip; mod debug; -mod history_store; mod inline_assistant; mod inline_prompt_editor; mod language_model_selector; @@ -22,6 +21,7 @@ mod terminal_codegen; mod terminal_inline_assistant; mod text_thread_editor; mod thread_history; +mod thread_history2; mod tool_compatibility; mod ui; diff --git a/crates/agent_ui/src/thread_history.rs b/crates/agent_ui/src/thread_history.rs index 04ccec0975..b8d1db88d6 100644 --- a/crates/agent_ui/src/thread_history.rs +++ b/crates/agent_ui/src/thread_history.rs @@ -1,5 +1,5 @@ -use crate::history_store::{HistoryEntry, HistoryStore}; use crate::{AgentPanel, RemoveSelectedThread}; +use agent::history_store::{HistoryEntry, HistoryStore}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; From e72f6f99c88ac3ca389d49d73b02136e2e4b1916 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 17:27:57 -0600 Subject: [PATCH 05/25] Render history from the native agent --- Cargo.lock | 1 + crates/acp_thread/src/acp_thread.rs | 8 +- crates/acp_thread/src/connection.rs | 9 +- crates/agent/src/history_store.rs | 2 +- crates/agent2/Cargo.toml | 1 + crates/agent2/src/agent.rs | 50 +- crates/agent2/src/agent2.rs | 1 + crates/agent2/src/history_store.rs | 141 ++++ crates/agent2/src/native_agent_server.rs | 9 +- crates/agent_servers/src/acp.rs | 6 +- crates/agent_servers/src/acp/v0.rs | 12 +- crates/agent_servers/src/acp/v1.rs | 12 +- crates/agent_servers/src/agent_servers.rs | 4 +- crates/agent_servers/src/claude.rs | 12 +- crates/agent_servers/src/gemini.rs | 6 +- crates/agent_ui/src/acp.rs | 2 + crates/agent_ui/src/acp/thread_history.rs | 944 ++++++++++++++++++++++ crates/agent_ui/src/acp/thread_view.rs | 8 +- crates/agent_ui/src/agent_panel.rs | 24 +- crates/agent_ui/src/agent_ui.rs | 1 - 20 files changed, 1184 insertions(+), 69 deletions(-) create mode 100644 crates/agent2/src/history_store.rs create mode 100644 crates/agent_ui/src/acp/thread_history.rs diff --git a/Cargo.lock b/Cargo.lock index a20579cce3..910aefa034 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -197,6 +197,7 @@ dependencies = [ "agent_servers", "agent_settings", "anyhow", + "assistant_context", "assistant_tool", "assistant_tools", "chrono", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d8b0d5805a..59a491a5b9 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -634,8 +634,12 @@ impl PlanEntry { } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct AgentServerName(pub SharedString); + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AcpThreadMetadata { + pub agent: AgentServerName, pub id: acp::SessionId, pub title: SharedString, pub updated_at: DateTime, @@ -2309,10 +2313,6 @@ mod tests { Task::ready(Ok(thread)) } - fn list_threads(&self, _: &mut App) -> Task>> { - unimplemented!() - } - fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task> { if self.auth_methods().iter().any(|m| m.id == method) { Task::ready(Ok(())) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index bc31ad0fe7..dee18378ef 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,6 +2,7 @@ use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; +use futures::channel::mpsc::UnboundedReceiver; use gpui::{Entity, SharedString, Task}; use project::Project; use serde::{Deserialize, Serialize}; @@ -26,7 +27,9 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; - fn list_threads(&self, _cx: &mut App) -> Task>>; + fn list_threads(&self, _cx: &mut App) -> Option>> { + return None; + } fn auth_methods(&self) -> &[acp::AuthMethod]; @@ -266,10 +269,6 @@ mod test_support { unimplemented!() } - fn list_threads(&self, _: &mut App) -> Task>> { - unimplemented!() - } - fn prompt( &self, _id: Option, diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index eb39c3e454..4f2668384f 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -62,7 +62,7 @@ enum SerializedRecentOpen { pub struct HistoryStore { thread_store: Entity, - context_store: Entity, + pub context_store: Entity, recently_opened_entries: VecDeque, _subscriptions: Vec, _save_recently_opened_entries_task: Task<()>, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index de1b1f74ec..ed487e2e22 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -61,6 +61,7 @@ web_search.workspace = true which.workspace = true workspace-hack.workspace = true zstd.workspace = true +assistant_context.workspace = true [dev-dependencies] agent = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 2bcc24d00c..be8428afc5 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,4 +1,5 @@ use crate::ThreadsDatabase; +use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::{ AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, @@ -11,9 +12,9 @@ use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; -use futures::channel::mpsc; +use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; use futures::future::Shared; -use futures::{StreamExt, future}; +use futures::{SinkExt, StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; @@ -169,6 +170,7 @@ pub struct NativeAgent { project: Entity, prompt_store: Option>, thread_database: Shared, Arc>>>, + history_listeners: Vec>>, fs: Arc, _subscriptions: Vec, } @@ -217,6 +219,7 @@ impl NativeAgent { project, prompt_store, fs, + history_listeners: Vec::new(), _subscriptions: subscriptions, } }) @@ -755,21 +758,36 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn list_threads(&self, cx: &mut App) -> Task>> { - let database = self.0.read(cx).thread_database.clone(); - cx.background_executor().spawn(async move { - let database = database.await.map_err(|e| anyhow!(e))?; - let results = database.list_threads().await?; + fn list_threads(&self, cx: &mut App) -> Option>> { + dbg!("listing!"); + let (mut tx, rx) = futures::channel::mpsc::unbounded(); + let database = self.0.update(cx, |this, _| { + this.history_listeners.push(tx.clone()); + this.thread_database.clone() + }); + cx.background_executor() + .spawn(async move { + dbg!("listing!"); + let database = database.await.map_err(|e| anyhow!(e))?; + let results = database.list_threads().await?; - Ok(results - .into_iter() - .map(|thread| AcpThreadMetadata { - id: thread.id, - title: thread.title, - updated_at: thread.updated_at, - }) - .collect()) - }) + dbg!(&results); + tx.send( + results + .into_iter() + .map(|thread| AcpThreadMetadata { + agent: NATIVE_AGENT_SERVER_NAME.clone(), + id: thread.id, + title: thread.title, + updated_at: thread.updated_at, + }) + .collect(), + ) + .await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + Some(rx) } fn model_selector(&self) -> Option> { diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 31e18e3769..1813fe1880 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,5 +1,6 @@ mod agent; mod db; +pub mod history_store; mod native_agent_server; mod templates; mod thread; diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs new file mode 100644 index 0000000000..f3ab3a0a6d --- /dev/null +++ b/crates/agent2/src/history_store.rs @@ -0,0 +1,141 @@ +use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; +use agent::{ThreadId, thread_store::ThreadStore}; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result}; +use assistant_context::SavedContextMetadata; +use chrono::{DateTime, Utc}; +use collections::HashMap; +use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; +use itertools::Itertools; +use paths::contexts_dir; +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; +use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; +use util::ResultExt as _; + +const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; +const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; +const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50); + +#[derive(Clone, Debug)] +pub enum HistoryEntry { + Thread(AcpThreadMetadata), + Context(SavedContextMetadata), +} + +impl HistoryEntry { + pub fn updated_at(&self) -> DateTime { + match self { + HistoryEntry::Thread(thread) => thread.updated_at, + HistoryEntry::Context(context) => context.mtime.to_utc(), + } + } + + pub fn id(&self) -> HistoryEntryId { + match self { + HistoryEntry::Thread(thread) => { + HistoryEntryId::Thread(thread.agent.clone(), thread.id.clone()) + } + HistoryEntry::Context(context) => HistoryEntryId::Context(context.path.clone()), + } + } + + pub fn title(&self) -> &SharedString { + match self { + HistoryEntry::Thread(thread) => &thread.title, + HistoryEntry::Context(context) => &context.title, + } + } +} + +/// Generic identifier for a history entry. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum HistoryEntryId { + Thread(AgentServerName, acp::SessionId), + Context(Arc), +} + +#[derive(Serialize, Deserialize)] +enum SerializedRecentOpen { + Thread(String), + ContextName(String), + /// Old format which stores the full path + Context(String), +} + +pub struct AgentHistory { + entries: HashMap, + _task: Task>, +} + +pub struct HistoryStore { + agents: HashMap, +} + +impl HistoryStore { + pub fn new(cx: &mut Context) -> Self { + Self { + agents: HashMap::default(), + } + } + + pub fn register_agent( + &mut self, + agent_name: AgentServerName, + connection: &dyn AgentConnection, + cx: &mut Context, + ) { + let Some(mut history) = connection.list_threads(cx) else { + return; + }; + let task = cx.spawn(async move |this, cx| { + while let Some(updated_history) = history.next().await { + dbg!(&updated_history); + this.update(cx, |this, cx| { + for entry in updated_history { + let agent = this + .agents + .get_mut(&entry.agent) + .context("agent not found")?; + agent.entries.insert(entry.id.clone(), entry); + } + cx.notify(); + anyhow::Ok(()) + })?? + } + Ok(()) + }); + self.agents.insert( + agent_name, + AgentHistory { + entries: Default::default(), + _task: task, + }, + ); + } + + pub fn entries(&self, cx: &mut Context) -> Vec { + let mut history_entries = Vec::new(); + + #[cfg(debug_assertions)] + if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() { + return history_entries; + } + + history_entries.extend( + self.agents + .values() + .flat_map(|agent| agent.entries.values()) + .cloned() + .map(HistoryEntry::Thread), + ); + // todo!() include the text threads in here. + + history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at())); + dbg!(history_entries) + } + + pub fn recent_entries(&self, limit: usize, cx: &mut Context) -> Vec { + self.entries(cx).into_iter().take(limit).collect() + } +} diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index cadd88a846..c8ff38e893 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -1,11 +1,13 @@ use std::{path::Path, rc::Rc, sync::Arc}; +use acp_thread::AgentServerName; use agent_servers::AgentServer; use anyhow::Result; use fs::Fs; use gpui::{App, Entity, Task}; use project::Project; use prompt_store::PromptStore; +use ui::SharedString; use crate::{NativeAgent, NativeAgentConnection, templates::Templates}; @@ -20,9 +22,12 @@ impl NativeAgentServer { } } +pub const NATIVE_AGENT_SERVER_NAME: AgentServerName = + AgentServerName(SharedString::new_static("Native Agent")); + impl AgentServer for NativeAgentServer { - fn name(&self) -> &'static str { - "Native Agent" + fn name(&self) -> AgentServerName { + NATIVE_AGENT_SERVER_NAME.clone() } fn empty_state_headline(&self) -> &'static str { diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 00e3e3df50..7891316925 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -1,7 +1,7 @@ use std::{path::Path, rc::Rc}; use crate::AgentServerCommand; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentServerName}; use anyhow::Result; use gpui::AsyncApp; use thiserror::Error; @@ -14,12 +14,12 @@ mod v1; pub struct UnsupportedVersion; pub async fn connect( - server_name: &'static str, + server_name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, ) -> Result> { - let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await; + let conn = v1::AcpConnection::stdio(server_name.clone(), command.clone(), &root_dir, cx).await; match conn { Ok(conn) => Ok(Rc::new(conn) as _), diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index d0fda4e020..ffb3d5a72c 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -10,7 +10,7 @@ use ui::App; use util::ResultExt as _; use crate::AgentServerCommand; -use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired}; #[derive(Clone)] struct OldAcpClientDelegate { @@ -354,7 +354,7 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu } pub struct AcpConnection { - pub name: &'static str, + pub name: AgentServerName, pub connection: acp_old::AgentConnection, pub _child_status: Task>, pub current_thread: Rc>>, @@ -362,7 +362,7 @@ pub struct AcpConnection { impl AcpConnection { pub fn stdio( - name: &'static str, + name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, @@ -443,7 +443,7 @@ impl AgentConnection for AcpConnection { cx.update(|cx| { let thread = cx.new(|cx| { let session_id = acp::SessionId("acp-old-no-id".into()); - AcpThread::new(self.name, self.clone(), project, session_id, cx) + AcpThread::new(self.name.0.clone(), self.clone(), project, session_id, cx) }); current_thread.replace(thread.downgrade()); thread @@ -451,10 +451,6 @@ impl AgentConnection for AcpConnection { }) } - fn list_threads(&self, _cx: &mut App) -> Task>> { - Task::ready(Ok(Vec::default())) - } - fn auth_methods(&self) -> &[acp::AuthMethod] { &[] } diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index b412cc34c0..7cf89bba64 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -11,10 +11,10 @@ use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use crate::{AgentServerCommand, acp::UnsupportedVersion}; -use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection, AuthRequired}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName, AuthRequired}; pub struct AcpConnection { - server_name: &'static str, + server_name: AgentServerName, connection: Rc, sessions: Rc>>, auth_methods: Vec, @@ -29,7 +29,7 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1; impl AcpConnection { pub async fn stdio( - server_name: &'static str, + server_name: AgentServerName, command: AgentServerCommand, root_dir: &Path, cx: &mut AsyncApp, @@ -135,7 +135,7 @@ impl AgentConnection for AcpConnection { let thread = cx.new(|cx| { AcpThread::new( - self.server_name, + self.server_name.0.clone(), self.clone(), project, session_id.clone(), @@ -169,10 +169,6 @@ impl AgentConnection for AcpConnection { }) } - fn list_threads(&self, _cx: &mut App) -> Task>> { - Task::ready(Ok(Vec::default())) - } - fn prompt( &self, _id: Option, diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index b3b8a33170..0836af1ba9 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -10,7 +10,7 @@ pub use claude::*; pub use gemini::*; pub use settings::*; -use acp_thread::AgentConnection; +use acp_thread::{AgentConnection, AgentServerName}; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -30,7 +30,7 @@ pub fn init(cx: &mut App) { pub trait AgentServer: Send { fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; + fn name(&self) -> AgentServerName; fn empty_state_headline(&self) -> &'static str; fn empty_state_message(&self) -> &'static str; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 291d27fd6a..c0d0b132a2 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -30,18 +30,18 @@ use util::{ResultExt, debug_panic}; use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig}; use crate::claude::tools::ClaudeTool; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AcpThreadMetadata, AgentConnection}; +use acp_thread::{AcpThread, AgentConnection, AgentServerName}; #[derive(Clone)] pub struct ClaudeCode; impl AgentServer for ClaudeCode { - fn name(&self) -> &'static str { - "Claude Code" + fn name(&self) -> AgentServerName { + AgentServerName("Claude Code".into()) } fn empty_state_headline(&self) -> &'static str { - self.name() + "Claude Code" } fn empty_state_message(&self) -> &'static str { @@ -209,10 +209,6 @@ impl AgentConnection for ClaudeAgentConnection { Task::ready(Err(anyhow!("Authentication not supported"))) } - fn list_threads(&self, _cx: &mut App) -> Task>> { - Task::ready(Ok(Vec::default())) - } - fn prompt( &self, _id: Option, diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index ad883f6da8..ab428fe5b3 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -2,7 +2,7 @@ use std::path::Path; use std::rc::Rc; use crate::{AgentServer, AgentServerCommand}; -use acp_thread::{AgentConnection, LoadError}; +use acp_thread::{AgentConnection, AgentServerName, LoadError}; use anyhow::Result; use gpui::{Entity, Task}; use project::Project; @@ -17,8 +17,8 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; impl AgentServer for Gemini { - fn name(&self) -> &'static str { - "Gemini" + fn name(&self) -> AgentServerName { + AgentServerName("Gemini".into()) } fn empty_state_headline(&self) -> &'static str { diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index 831d296eeb..c3fe90760e 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -3,8 +3,10 @@ mod entry_view_state; mod message_editor; mod model_selector; mod model_selector_popover; +mod thread_history; mod thread_view; pub use model_selector::AcpModelSelector; pub use model_selector_popover::AcpModelSelectorPopover; +pub use thread_history::AcpThreadHistory; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs new file mode 100644 index 0000000000..d1ebf59d75 --- /dev/null +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -0,0 +1,944 @@ +use crate::{AgentPanel, RemoveSelectedThread}; +use agent_servers::AgentServer; +use agent2::{ + NativeAgentServer, + history_store::{HistoryEntry, HistoryStore}, +}; +use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; +use editor::{Editor, EditorEvent}; +use fuzzy::{StringMatch, StringMatchCandidate}; +use gpui::{ + App, ClickEvent, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, + UniformListScrollHandle, WeakEntity, Window, uniform_list, +}; +use project::Project; +use std::{fmt::Display, ops::Range, sync::Arc}; +use time::{OffsetDateTime, UtcOffset}; +use ui::{ + HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState, + Tooltip, prelude::*, +}; +use util::ResultExt; + +pub struct AcpThreadHistory { + agent_panel: WeakEntity, + history_store: Entity, + scroll_handle: UniformListScrollHandle, + selected_index: usize, + hovered_index: Option, + search_editor: Entity, + all_entries: Arc>, + // When the search is empty, we display date separators between history entries + // This vector contains an enum of either a separator or an actual entry + separated_items: Vec, + // Maps entry indexes to list item indexes + separated_item_indexes: Vec, + _separated_items_task: Option>, + search_state: SearchState, + scrollbar_visibility: bool, + scrollbar_state: ScrollbarState, + _subscriptions: Vec, +} + +enum SearchState { + Empty, + Searching { + query: SharedString, + _task: Task<()>, + }, + Searched { + query: SharedString, + matches: Vec, + }, +} + +enum ListItemType { + BucketSeparator(TimeBucket), + Entry { + index: usize, + format: EntryTimeFormat, + }, +} + +impl ListItemType { + fn entry_index(&self) -> Option { + match self { + ListItemType::BucketSeparator(_) => None, + ListItemType::Entry { index, .. } => Some(*index), + } + } +} + +impl AcpThreadHistory { + pub(crate) fn new( + agent_panel: WeakEntity, + project: &Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx)); + + let agent = NativeAgentServer::new(project.read(cx).fs().clone()); + + let root_dir = project + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()) + .unwrap_or_else(|| paths::home_dir().as_path().into()); + + // todo!() reuse this connection for sending messages + let connect = agent.connect(&root_dir, project, cx); + cx.spawn(async move |this, cx| { + let connection = connect.await?; + this.update(cx, |this, cx| { + this.history_store.update(cx, |this, cx| { + this.register_agent(agent.name(), connection.as_ref(), cx) + }) + })?; + // todo!() we must keep it alive + std::mem::forget(connection); + anyhow::Ok(()) + }) + .detach(); + + dbg!("hello!"); + let search_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("Search threads...", cx); + editor + }); + + let search_editor_subscription = + cx.subscribe(&search_editor, |this, search_editor, event, cx| { + if let EditorEvent::BufferEdited = event { + let query = search_editor.read(cx).text(cx); + this.search(query.into(), cx); + } + }); + + let history_store_subscription = cx.observe(&history_store, |this, _, cx| { + this.update_all_entries(cx); + }); + + let scroll_handle = UniformListScrollHandle::default(); + let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); + + let mut this = Self { + agent_panel, + history_store, + scroll_handle, + selected_index: 0, + hovered_index: None, + search_state: SearchState::Empty, + all_entries: Default::default(), + separated_items: Default::default(), + separated_item_indexes: Default::default(), + search_editor, + scrollbar_visibility: true, + scrollbar_state, + _subscriptions: vec![search_editor_subscription, history_store_subscription], + _separated_items_task: None, + }; + this.update_all_entries(cx); + this + } + + fn update_all_entries(&mut self, cx: &mut Context) { + let new_entries: Arc> = self + .history_store + .update(cx, |store, cx| store.entries(cx)) + .into(); + + self._separated_items_task.take(); + + let mut items = Vec::with_capacity(new_entries.len() + 1); + let mut indexes = Vec::with_capacity(new_entries.len() + 1); + + let bg_task = cx.background_spawn(async move { + let mut bucket = None; + let today = Local::now().naive_local().date(); + + for (index, entry) in new_entries.iter().enumerate() { + let entry_date = entry + .updated_at() + .with_timezone(&Local) + .naive_local() + .date(); + let entry_bucket = TimeBucket::from_dates(today, entry_date); + + if Some(entry_bucket) != bucket { + bucket = Some(entry_bucket); + items.push(ListItemType::BucketSeparator(entry_bucket)); + } + + indexes.push(items.len() as u32); + items.push(ListItemType::Entry { + index, + format: entry_bucket.into(), + }); + } + (new_entries, items, indexes) + }); + + let task = cx.spawn(async move |this, cx| { + let (new_entries, items, indexes) = bg_task.await; + this.update(cx, |this, cx| { + let previously_selected_entry = + this.all_entries.get(this.selected_index).map(|e| e.id()); + + this.all_entries = new_entries; + this.separated_items = items; + this.separated_item_indexes = indexes; + + match &this.search_state { + SearchState::Empty => { + if this.selected_index >= this.all_entries.len() { + this.set_selected_entry_index( + this.all_entries.len().saturating_sub(1), + cx, + ); + } else if let Some(prev_id) = previously_selected_entry { + if let Some(new_ix) = this + .all_entries + .iter() + .position(|probe| probe.id() == prev_id) + { + this.set_selected_entry_index(new_ix, cx); + } + } + } + SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => { + this.search(query.clone(), cx); + } + } + + cx.notify(); + }) + .log_err(); + }); + self._separated_items_task = Some(task); + } + + fn search(&mut self, query: SharedString, cx: &mut Context) { + if query.is_empty() { + self.search_state = SearchState::Empty; + cx.notify(); + return; + } + + let all_entries = self.all_entries.clone(); + + let fuzzy_search_task = cx.background_spawn({ + let query = query.clone(); + let executor = cx.background_executor().clone(); + async move { + let mut candidates = Vec::with_capacity(all_entries.len()); + + for (idx, entry) in all_entries.iter().enumerate() { + match entry { + HistoryEntry::Thread(thread) => { + candidates.push(StringMatchCandidate::new(idx, &thread.title)); + } + HistoryEntry::Context(context) => { + candidates.push(StringMatchCandidate::new(idx, &context.title)); + } + } + } + + const MAX_MATCHES: usize = 100; + + fuzzy::match_strings( + &candidates, + &query, + false, + true, + MAX_MATCHES, + &Default::default(), + executor, + ) + .await + } + }); + + let task = cx.spawn({ + let query = query.clone(); + async move |this, cx| { + let matches = fuzzy_search_task.await; + + this.update(cx, |this, cx| { + let SearchState::Searching { + query: current_query, + _task, + } = &this.search_state + else { + return; + }; + + if &query == current_query { + this.search_state = SearchState::Searched { + query: query.clone(), + matches, + }; + + this.set_selected_entry_index(0, cx); + cx.notify(); + }; + }) + .log_err(); + } + }); + + self.search_state = SearchState::Searching { query, _task: task }; + cx.notify(); + } + + fn matched_count(&self) -> usize { + match &self.search_state { + SearchState::Empty => self.all_entries.len(), + SearchState::Searching { .. } => 0, + SearchState::Searched { matches, .. } => matches.len(), + } + } + + fn list_item_count(&self) -> usize { + match &self.search_state { + SearchState::Empty => self.separated_items.len(), + SearchState::Searching { .. } => 0, + SearchState::Searched { matches, .. } => matches.len(), + } + } + + fn search_produced_no_matches(&self) -> bool { + match &self.search_state { + SearchState::Empty => false, + SearchState::Searching { .. } => false, + SearchState::Searched { matches, .. } => matches.is_empty(), + } + } + + fn get_match(&self, ix: usize) -> Option<&HistoryEntry> { + match &self.search_state { + SearchState::Empty => self.all_entries.get(ix), + SearchState::Searching { .. } => None, + SearchState::Searched { matches, .. } => matches + .get(ix) + .and_then(|m| self.all_entries.get(m.candidate_id)), + } + } + + pub fn select_previous( + &mut self, + _: &menu::SelectPrevious, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + if self.selected_index == 0 { + self.set_selected_entry_index(count - 1, cx); + } else { + self.set_selected_entry_index(self.selected_index - 1, cx); + } + } + } + + pub fn select_next( + &mut self, + _: &menu::SelectNext, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + if self.selected_index == count - 1 { + self.set_selected_entry_index(0, cx); + } else { + self.set_selected_entry_index(self.selected_index + 1, cx); + } + } + } + + fn select_first( + &mut self, + _: &menu::SelectFirst, + _window: &mut Window, + cx: &mut Context, + ) { + let count = self.matched_count(); + if count > 0 { + self.set_selected_entry_index(0, cx); + } + } + + fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { + let count = self.matched_count(); + if count > 0 { + self.set_selected_entry_index(count - 1, cx); + } + } + + fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context) { + self.selected_index = entry_index; + + let scroll_ix = match self.search_state { + SearchState::Empty | SearchState::Searching { .. } => self + .separated_item_indexes + .get(entry_index) + .map(|ix| *ix as usize) + .unwrap_or(entry_index + 1), + SearchState::Searched { .. } => entry_index, + }; + + self.scroll_handle + .scroll_to_item(scroll_ix, ScrollStrategy::Top); + + cx.notify(); + } + + fn render_scrollbar(&self, cx: &mut Context) -> Option> { + if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) { + return None; + } + + Some( + div() + .occlude() + .id("thread-history-scroll") + .h_full() + .bg(cx.theme().colors().panel_background.opacity(0.8)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .absolute() + .right_1() + .top_0() + .bottom_0() + .w_4() + .pl_1() + .cursor_default() + .on_mouse_move(cx.listener(|_, _, _window, cx| { + cx.notify(); + cx.stop_propagation() + })) + .on_hover(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_scroll_wheel(cx.listener(|_, _, _window, cx| { + cx.notify(); + })) + .children(Scrollbar::vertical(self.scrollbar_state.clone())), + ) + } + + fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + if let Some(entry) = self.get_match(self.selected_index) { + let task_result = match entry { + HistoryEntry::Thread(thread) => { + self.agent_panel.update(cx, move |agent_panel, cx| todo!()) + } + HistoryEntry::Context(context) => { + self.agent_panel.update(cx, move |agent_panel, cx| { + agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) + }) + } + }; + + if let Some(task) = task_result.log_err() { + task.detach_and_log_err(cx); + }; + + cx.notify(); + } + } + + fn remove_selected_thread( + &mut self, + _: &RemoveSelectedThread, + _window: &mut Window, + cx: &mut Context, + ) { + if let Some(entry) = self.get_match(self.selected_index) { + let task_result = match entry { + HistoryEntry::Thread(thread) => todo!(), + HistoryEntry::Context(context) => self + .agent_panel + .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), + }; + + if let Some(task) = task_result.log_err() { + task.detach_and_log_err(cx); + }; + + cx.notify(); + } + } + + fn list_items( + &mut self, + range: Range, + _window: &mut Window, + cx: &mut Context, + ) -> Vec { + let range_start = range.start; + + match &self.search_state { + SearchState::Empty => self + .separated_items + .get(range) + .iter() + .flat_map(|items| { + items + .iter() + .map(|item| self.render_list_item(item.entry_index(), item, vec![], cx)) + }) + .collect(), + SearchState::Searched { matches, .. } => matches[range] + .iter() + .enumerate() + .map(|(ix, m)| { + self.render_list_item( + Some(range_start + ix), + &ListItemType::Entry { + index: m.candidate_id, + format: EntryTimeFormat::DateAndTime, + }, + m.positions.clone(), + cx, + ) + }) + .collect(), + SearchState::Searching { .. } => { + vec![] + } + } + } + + fn render_list_item( + &self, + list_entry_ix: Option, + item: &ListItemType, + highlight_positions: Vec, + cx: &Context, + ) -> AnyElement { + match item { + ListItemType::Entry { index, format } => match self.all_entries.get(*index) { + Some(entry) => h_flex() + .w_full() + .pb_1() + .child( + HistoryEntryElement::new(entry.clone(), self.agent_panel.clone()) + .highlight_positions(highlight_positions) + .timestamp_format(*format) + .selected(list_entry_ix == Some(self.selected_index)) + .hovered(list_entry_ix == self.hovered_index) + .on_hover(cx.listener(move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_index = list_entry_ix; + } else if this.hovered_index == list_entry_ix { + this.hovered_index = None; + } + + cx.notify(); + })) + .into_any_element(), + ) + .into_any(), + None => Empty.into_any_element(), + }, + ListItemType::BucketSeparator(bucket) => div() + .px(DynamicSpacing::Base06.rems(cx)) + .pt_2() + .pb_1() + .child( + Label::new(bucket.to_string()) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + } + } +} + +impl Focusable for AcpThreadHistory { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.search_editor.focus_handle(cx) + } +} + +impl Render for AcpThreadHistory { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .key_context("ThreadHistory") + .size_full() + .on_action(cx.listener(Self::select_previous)) + .on_action(cx.listener(Self::select_next)) + .on_action(cx.listener(Self::select_first)) + .on_action(cx.listener(Self::select_last)) + .on_action(cx.listener(Self::confirm)) + .on_action(cx.listener(Self::remove_selected_thread)) + .when(!self.all_entries.is_empty(), |parent| { + parent.child( + h_flex() + .h(px(41.)) // Match the toolbar perfectly + .w_full() + .py_1() + .px_2() + .gap_2() + .justify_between() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + Icon::new(IconName::MagnifyingGlass) + .color(Color::Muted) + .size(IconSize::Small), + ) + .child(self.search_editor.clone()), + ) + }) + .child({ + let view = v_flex() + .id("list-container") + .relative() + .overflow_hidden() + .flex_grow(); + + if self.all_entries.is_empty() { + view.justify_center() + .child( + h_flex().w_full().justify_center().child( + Label::new("You don't have any past threads yet.") + .size(LabelSize::Small), + ), + ) + } else if self.search_produced_no_matches() { + view.justify_center().child( + h_flex().w_full().justify_center().child( + Label::new("No threads match your search.").size(LabelSize::Small), + ), + ) + } else { + view.pr_5() + .child( + uniform_list( + "thread-history", + self.list_item_count(), + cx.processor(|this, range: Range, window, cx| { + this.list_items(range, window, cx) + }), + ) + .p_1() + .track_scroll(self.scroll_handle.clone()) + .flex_grow(), + ) + .when_some(self.render_scrollbar(cx), |div, scrollbar| { + div.child(scrollbar) + }) + } + }) + } +} + +#[derive(IntoElement)] +pub struct HistoryEntryElement { + entry: HistoryEntry, + agent_panel: WeakEntity, + selected: bool, + hovered: bool, + highlight_positions: Vec, + timestamp_format: EntryTimeFormat, + on_hover: Box, +} + +impl HistoryEntryElement { + pub fn new(entry: HistoryEntry, agent_panel: WeakEntity) -> Self { + Self { + entry, + agent_panel, + selected: false, + hovered: false, + highlight_positions: vec![], + timestamp_format: EntryTimeFormat::DateAndTime, + on_hover: Box::new(|_, _, _| {}), + } + } + + pub fn selected(mut self, selected: bool) -> Self { + self.selected = selected; + self + } + + pub fn hovered(mut self, hovered: bool) -> Self { + self.hovered = hovered; + self + } + + pub fn highlight_positions(mut self, positions: Vec) -> Self { + self.highlight_positions = positions; + self + } + + pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self { + self.on_hover = Box::new(on_hover); + self + } + + pub fn timestamp_format(mut self, format: EntryTimeFormat) -> Self { + self.timestamp_format = format; + self + } +} + +impl RenderOnce for HistoryEntryElement { + fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { + let (id, summary, timestamp) = match &self.entry { + HistoryEntry::Thread(thread) => ( + thread.id.to_string(), + thread.title.clone(), + thread.updated_at.timestamp(), + ), + HistoryEntry::Context(context) => ( + context.path.to_string_lossy().to_string(), + context.title.clone(), + context.mtime.timestamp(), + ), + }; + + let thread_timestamp = + self.timestamp_format + .format_timestamp(&self.agent_panel, timestamp, cx); + + ListItem::new(SharedString::from(id)) + .rounded() + .toggle_state(self.selected) + .spacing(ListItemSpacing::Sparse) + .start_slot( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child( + HighlightedLabel::new(summary, self.highlight_positions) + .size(LabelSize::Small) + .truncate(), + ) + .child( + Label::new(thread_timestamp) + .color(Color::Muted) + .size(LabelSize::XSmall), + ), + ) + .on_hover(self.on_hover) + .end_slot::(if self.hovered || self.selected { + Some( + IconButton::new("delete", IconName::Trash) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip(move |window, cx| { + Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) + }) + .on_click({ + let agent_panel = self.agent_panel.clone(); + + let f: Box = + match &self.entry { + HistoryEntry::Thread(thread) => { + let id = thread.id.clone(); + + Box::new(move |_event, _window, cx| { + agent_panel + .update(cx, |agent_panel, cx| { + todo!() + // this.delete_thread(&id, cx) + // .detach_and_log_err(cx); + }) + .ok(); + }) + } + HistoryEntry::Context(context) => { + let path = context.path.clone(); + + Box::new(move |_event, _window, cx| { + agent_panel + .update(cx, |this, cx| { + this.delete_context(path.clone(), cx) + .detach_and_log_err(cx); + }) + .ok(); + }) + } + }; + f + }), + ) + } else { + None + }) + .on_click({ + let agent_panel = self.agent_panel.clone(); + + let f: Box = match &self.entry + { + HistoryEntry::Thread(thread) => { + let id = thread.id.clone(); + Box::new(move |_event, window, cx| { + agent_panel + .update(cx, |agent_panel, cx| { + // todo!() + }) + .ok(); + }) + } + HistoryEntry::Context(context) => { + let path = context.path.clone(); + Box::new(move |_event, window, cx| { + agent_panel + .update(cx, |this, cx| { + this.open_saved_prompt_editor(path.clone(), window, cx) + .detach_and_log_err(cx); + }) + .ok(); + }) + } + }; + f + }) + } +} + +#[derive(Clone, Copy)] +pub enum EntryTimeFormat { + DateAndTime, + TimeOnly, +} + +impl EntryTimeFormat { + fn format_timestamp( + &self, + agent_panel: &WeakEntity, + timestamp: i64, + cx: &App, + ) -> String { + let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap(); + let timezone = agent_panel + .read_with(cx, |this, _cx| this.local_timezone()) + .unwrap_or(UtcOffset::UTC); + + match &self { + EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp( + timestamp, + OffsetDateTime::now_utc(), + timezone, + time_format::TimestampFormat::EnhancedAbsolute, + ), + EntryTimeFormat::TimeOnly => time_format::format_time(timestamp), + } + } +} + +impl From for EntryTimeFormat { + fn from(bucket: TimeBucket) -> Self { + match bucket { + TimeBucket::Today => EntryTimeFormat::TimeOnly, + TimeBucket::Yesterday => EntryTimeFormat::TimeOnly, + TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime, + TimeBucket::PastWeek => EntryTimeFormat::DateAndTime, + TimeBucket::All => EntryTimeFormat::DateAndTime, + } + } +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +enum TimeBucket { + Today, + Yesterday, + ThisWeek, + PastWeek, + All, +} + +impl TimeBucket { + fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self { + if date == reference { + return TimeBucket::Today; + } + + if date == reference - TimeDelta::days(1) { + return TimeBucket::Yesterday; + } + + let week = date.iso_week(); + + if reference.iso_week() == week { + return TimeBucket::ThisWeek; + } + + let last_week = (reference - TimeDelta::days(7)).iso_week(); + + if week == last_week { + return TimeBucket::PastWeek; + } + + TimeBucket::All + } +} + +impl Display for TimeBucket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeBucket::Today => write!(f, "Today"), + TimeBucket::Yesterday => write!(f, "Yesterday"), + TimeBucket::ThisWeek => write!(f, "This Week"), + TimeBucket::PastWeek => write!(f, "Past Week"), + TimeBucket::All => write!(f, "All"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::NaiveDate; + + #[test] + fn test_time_bucket_from_dates() { + let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap(); + + let date = today; + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today); + + let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday); + + let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek); + + // All: not in this week or last week + let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All); + + // Test year boundary cases + let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap(); + + let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap(); + assert_eq!( + TimeBucket::from_dates(new_year, date), + TimeBucket::Yesterday + ); + + let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap(); + assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek); + } +} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index a459c36f81..af97f944e9 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -3583,7 +3583,7 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { #[cfg(test)] pub(crate) mod tests { - use acp_thread::{AcpThreadMetadata, StubAgentConnection}; + use acp_thread::{AgentServerName, StubAgentConnection}; use agent::{TextThreadStore, ThreadStore}; use agent_client_protocol::SessionId; use editor::EditorSettings; @@ -3764,7 +3764,7 @@ pub(crate) mod tests { unimplemented!() } - fn name(&self) -> &'static str { + fn name(&self) -> AgentServerName { unimplemented!() } @@ -3819,10 +3819,6 @@ pub(crate) mod tests { unimplemented!() } - fn list_threads(&self, _cx: &mut App) -> Task>> { - Task::ready(Ok(vec![])) - } - fn prompt( &self, _id: Option, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 73915195f5..f54ea9b31a 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -5,10 +5,12 @@ use std::sync::Arc; use std::time::Duration; use agent_servers::AgentServer; +use agent2::NativeAgentServer; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; use crate::NewExternalAgentThread; +use crate::acp::AcpThreadHistory; use crate::agent_diff::AgentDiffThread; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, @@ -478,6 +480,7 @@ pub struct AgentPanel { previous_view: Option, history_store: Entity, history: Entity, + acp_history: Entity, hovered_recent_history_item: Option, new_thread_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, @@ -743,6 +746,9 @@ impl AgentPanel { ) }); + let acp_history = + cx.new(|cx| AcpThreadHistory::new(weak_self.clone(), &project, window, cx)); + Self { active_view, workspace, @@ -764,6 +770,7 @@ impl AgentPanel { previous_view: None, history_store: history_store.clone(), history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)), + acp_history, hovered_recent_history_item: None, new_thread_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), @@ -1652,7 +1659,14 @@ impl Focusable for AgentPanel { match &self.active_view { ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx), ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx), - ActiveView::History => self.history.focus_handle(cx), + ActiveView::History => { + if cx.has_flag::() { + self.acp_history.focus_handle(cx) + } else { + self.history.focus_handle(cx) + } + } + ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx), ActiveView::Configuration => { if let Some(configuration) = self.configuration.as_ref() { @@ -3499,7 +3513,13 @@ impl Render for AgentPanel { ActiveView::ExternalAgentThread { thread_view, .. } => parent .child(thread_view.clone()) .child(self.render_drag_target(cx)), - ActiveView::History => parent.child(self.history.clone()), + ActiveView::History => { + if cx.has_flag::() { + parent.child(self.acp_history.clone()) + } else { + parent.child(self.history.clone()) + } + } ActiveView::TextThread { context_editor, buffer_search_bar, diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 946d24d9ae..4f5f022593 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -21,7 +21,6 @@ mod terminal_codegen; mod terminal_inline_assistant; mod text_thread_editor; mod thread_history; -mod thread_history2; mod tool_compatibility; mod ui; From 1b793331b355313f12dc503cde04b7971b83e48c Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 22:32:52 -0600 Subject: [PATCH 06/25] Tidy up acp thread view implementation --- crates/agent2/src/history_store.rs | 2 +- crates/agent_ui/src/acp/thread_history.rs | 374 ++++++++-------------- 2 files changed, 126 insertions(+), 250 deletions(-) diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index f3ab3a0a6d..fb4f34f9c5 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -132,7 +132,7 @@ impl HistoryStore { // todo!() include the text threads in here. history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at())); - dbg!(history_entries) + history_entries } pub fn recent_entries(&self, limit: usize, cx: &mut Context) -> Vec { diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index d1ebf59d75..4bf409a2af 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -8,7 +8,7 @@ use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - App, ClickEvent, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, + App, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle, WeakEntity, Window, uniform_list, }; use project::Project; @@ -37,6 +37,7 @@ pub struct AcpThreadHistory { search_state: SearchState, scrollbar_visibility: bool, scrollbar_state: ScrollbarState, + local_timezone: UtcOffset, _subscriptions: Vec, } @@ -60,15 +61,6 @@ enum ListItemType { }, } -impl ListItemType { - fn entry_index(&self) -> Option { - match self { - ListItemType::BucketSeparator(_) => None, - ListItemType::Entry { index, .. } => Some(*index), - } - } -} - impl AcpThreadHistory { pub(crate) fn new( agent_panel: WeakEntity, @@ -137,6 +129,10 @@ impl AcpThreadHistory { search_editor, scrollbar_visibility: true, scrollbar_state, + local_timezone: UtcOffset::from_whole_seconds( + chrono::Local::now().offset().local_minus_utc(), + ) + .unwrap(), _subscriptions: vec![search_editor_subscription, history_store_subscription], _separated_items_task: None, }; @@ -434,24 +430,29 @@ impl AcpThreadHistory { } fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - if let Some(entry) = self.get_match(self.selected_index) { - let task_result = match entry { - HistoryEntry::Thread(thread) => { - self.agent_panel.update(cx, move |agent_panel, cx| todo!()) - } - HistoryEntry::Context(context) => { - self.agent_panel.update(cx, move |agent_panel, cx| { - agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) - }) - } - }; + self.confirm_entry(self.selected_index, window, cx); + } - if let Some(task) = task_result.log_err() { - task.detach_and_log_err(cx); - }; + fn confirm_entry(&mut self, ix: usize, window: &mut Window, cx: &mut Context) { + let Some(entry) = self.get_match(ix) else { + return; + }; + let task_result = match entry { + HistoryEntry::Thread(thread) => { + self.agent_panel.update(cx, move |agent_panel, cx| todo!()) + } + HistoryEntry::Context(context) => { + self.agent_panel.update(cx, move |agent_panel, cx| { + agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) + }) + } + }; - cx.notify(); - } + if let Some(task) = task_result.log_err() { + task.detach_and_log_err(cx); + }; + + cx.notify(); } fn remove_selected_thread( @@ -460,20 +461,25 @@ impl AcpThreadHistory { _window: &mut Window, cx: &mut Context, ) { - if let Some(entry) = self.get_match(self.selected_index) { - let task_result = match entry { - HistoryEntry::Thread(thread) => todo!(), - HistoryEntry::Context(context) => self - .agent_panel - .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), - }; + self.remove_thread(self.selected_index, cx) + } - if let Some(task) = task_result.log_err() { - task.detach_and_log_err(cx); - }; + fn remove_thread(&mut self, ix: usize, cx: &mut Context) { + let Some(entry) = self.get_match(ix) else { + return; + }; + let task_result = match entry { + HistoryEntry::Thread(thread) => todo!(), + HistoryEntry::Context(context) => self + .agent_panel + .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), + }; - cx.notify(); - } + if let Some(task) = task_result.log_err() { + task.detach_and_log_err(cx); + }; + + cx.notify(); } fn list_items( @@ -482,8 +488,6 @@ impl AcpThreadHistory { _window: &mut Window, cx: &mut Context, ) -> Vec { - let range_start = range.start; - match &self.search_state { SearchState::Empty => self .separated_items @@ -492,22 +496,20 @@ impl AcpThreadHistory { .flat_map(|items| { items .iter() - .map(|item| self.render_list_item(item.entry_index(), item, vec![], cx)) + .map(|item| self.render_list_item(item, vec![], cx)) }) .collect(), SearchState::Searched { matches, .. } => matches[range] .iter() - .enumerate() - .map(|(ix, m)| { - self.render_list_item( - Some(range_start + ix), - &ListItemType::Entry { - index: m.candidate_id, - format: EntryTimeFormat::DateAndTime, - }, + .filter_map(|m| { + let entry = self.all_entries.get(m.candidate_id)?; + Some(self.render_history_entry( + entry, + EntryTimeFormat::DateAndTime, + m.candidate_id, m.positions.clone(), cx, - ) + )) }) .collect(), SearchState::Searching { .. } => { @@ -518,33 +520,14 @@ impl AcpThreadHistory { fn render_list_item( &self, - list_entry_ix: Option, item: &ListItemType, highlight_positions: Vec, cx: &Context, ) -> AnyElement { match item { ListItemType::Entry { index, format } => match self.all_entries.get(*index) { - Some(entry) => h_flex() - .w_full() - .pb_1() - .child( - HistoryEntryElement::new(entry.clone(), self.agent_panel.clone()) - .highlight_positions(highlight_positions) - .timestamp_format(*format) - .selected(list_entry_ix == Some(self.selected_index)) - .hovered(list_entry_ix == self.hovered_index) - .on_hover(cx.listener(move |this, is_hovered, _window, cx| { - if *is_hovered { - this.hovered_index = list_entry_ix; - } else if this.hovered_index == list_entry_ix { - this.hovered_index = None; - } - - cx.notify(); - })) - .into_any_element(), - ) + Some(entry) => self + .render_history_entry(entry, *format, *index, highlight_positions, cx) .into_any(), None => Empty.into_any_element(), }, @@ -560,6 +543,75 @@ impl AcpThreadHistory { .into_any_element(), } } + + fn render_history_entry( + &self, + entry: &HistoryEntry, + format: EntryTimeFormat, + list_entry_ix: usize, + highlight_positions: Vec, + cx: &Context, + ) -> AnyElement { + let selected = list_entry_ix == self.selected_index; + let hovered = Some(list_entry_ix) == self.hovered_index; + let timestamp = entry.updated_at().timestamp(); + let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone); + + h_flex() + .w_full() + .pb_1() + .child( + ListItem::new(list_entry_ix) + .rounded() + .toggle_state(selected) + .spacing(ListItemSpacing::Sparse) + .start_slot( + h_flex() + .w_full() + .gap_2() + .justify_between() + .child( + HighlightedLabel::new(entry.title(), highlight_positions) + .size(LabelSize::Small) + .truncate(), + ) + .child( + Label::new(thread_timestamp) + .color(Color::Muted) + .size(LabelSize::XSmall), + ), + ) + .on_hover(cx.listener(move |this, is_hovered, _window, cx| { + if *is_hovered { + this.hovered_index = Some(list_entry_ix); + } else if this.hovered_index == Some(list_entry_ix) { + this.hovered_index = None; + } + + cx.notify(); + })) + .end_slot::(if hovered || selected { + Some( + IconButton::new("delete", IconName::Trash) + .shape(IconButtonShape::Square) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip(move |window, cx| { + Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) + }) + .on_click(cx.listener(move |this, _, _, cx| { + this.remove_thread(list_entry_ix, cx) + })), + ) + } else { + None + }) + .on_click(cx.listener(move |this, _, window, cx| { + this.confirm_entry(list_entry_ix, window, cx) + })), + ) + .into_any_element() + } } impl Focusable for AcpThreadHistory { @@ -641,174 +693,6 @@ impl Render for AcpThreadHistory { } } -#[derive(IntoElement)] -pub struct HistoryEntryElement { - entry: HistoryEntry, - agent_panel: WeakEntity, - selected: bool, - hovered: bool, - highlight_positions: Vec, - timestamp_format: EntryTimeFormat, - on_hover: Box, -} - -impl HistoryEntryElement { - pub fn new(entry: HistoryEntry, agent_panel: WeakEntity) -> Self { - Self { - entry, - agent_panel, - selected: false, - hovered: false, - highlight_positions: vec![], - timestamp_format: EntryTimeFormat::DateAndTime, - on_hover: Box::new(|_, _, _| {}), - } - } - - pub fn selected(mut self, selected: bool) -> Self { - self.selected = selected; - self - } - - pub fn hovered(mut self, hovered: bool) -> Self { - self.hovered = hovered; - self - } - - pub fn highlight_positions(mut self, positions: Vec) -> Self { - self.highlight_positions = positions; - self - } - - pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self { - self.on_hover = Box::new(on_hover); - self - } - - pub fn timestamp_format(mut self, format: EntryTimeFormat) -> Self { - self.timestamp_format = format; - self - } -} - -impl RenderOnce for HistoryEntryElement { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - let (id, summary, timestamp) = match &self.entry { - HistoryEntry::Thread(thread) => ( - thread.id.to_string(), - thread.title.clone(), - thread.updated_at.timestamp(), - ), - HistoryEntry::Context(context) => ( - context.path.to_string_lossy().to_string(), - context.title.clone(), - context.mtime.timestamp(), - ), - }; - - let thread_timestamp = - self.timestamp_format - .format_timestamp(&self.agent_panel, timestamp, cx); - - ListItem::new(SharedString::from(id)) - .rounded() - .toggle_state(self.selected) - .spacing(ListItemSpacing::Sparse) - .start_slot( - h_flex() - .w_full() - .gap_2() - .justify_between() - .child( - HighlightedLabel::new(summary, self.highlight_positions) - .size(LabelSize::Small) - .truncate(), - ) - .child( - Label::new(thread_timestamp) - .color(Color::Muted) - .size(LabelSize::XSmall), - ), - ) - .on_hover(self.on_hover) - .end_slot::(if self.hovered || self.selected { - Some( - IconButton::new("delete", IconName::Trash) - .shape(IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .tooltip(move |window, cx| { - Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) - }) - .on_click({ - let agent_panel = self.agent_panel.clone(); - - let f: Box = - match &self.entry { - HistoryEntry::Thread(thread) => { - let id = thread.id.clone(); - - Box::new(move |_event, _window, cx| { - agent_panel - .update(cx, |agent_panel, cx| { - todo!() - // this.delete_thread(&id, cx) - // .detach_and_log_err(cx); - }) - .ok(); - }) - } - HistoryEntry::Context(context) => { - let path = context.path.clone(); - - Box::new(move |_event, _window, cx| { - agent_panel - .update(cx, |this, cx| { - this.delete_context(path.clone(), cx) - .detach_and_log_err(cx); - }) - .ok(); - }) - } - }; - f - }), - ) - } else { - None - }) - .on_click({ - let agent_panel = self.agent_panel.clone(); - - let f: Box = match &self.entry - { - HistoryEntry::Thread(thread) => { - let id = thread.id.clone(); - Box::new(move |_event, window, cx| { - agent_panel - .update(cx, |agent_panel, cx| { - // todo!() - }) - .ok(); - }) - } - HistoryEntry::Context(context) => { - let path = context.path.clone(); - Box::new(move |_event, window, cx| { - agent_panel - .update(cx, |this, cx| { - this.open_saved_prompt_editor(path.clone(), window, cx) - .detach_and_log_err(cx); - }) - .ok(); - }) - } - }; - f - }) - } -} - #[derive(Clone, Copy)] pub enum EntryTimeFormat { DateAndTime, @@ -816,18 +700,10 @@ pub enum EntryTimeFormat { } impl EntryTimeFormat { - fn format_timestamp( - &self, - agent_panel: &WeakEntity, - timestamp: i64, - cx: &App, - ) -> String { + fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String { let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap(); - let timezone = agent_panel - .read_with(cx, |this, _cx| this.local_timezone()) - .unwrap_or(UtcOffset::UTC); - match &self { + match self { EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp( timestamp, OffsetDateTime::now_utc(), From eebe425c1de9dc23abe5a63ca2e5a99b1a6025b8 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 22:59:46 -0600 Subject: [PATCH 07/25] Tidy more --- crates/agent2/src/db.rs | 18 +++++++-- crates/agent_ui/src/acp.rs | 2 +- crates/agent_ui/src/acp/thread_history.rs | 49 +++++++++++++---------- crates/agent_ui/src/acp/thread_view.rs | 6 ++- crates/agent_ui/src/agent_panel.rs | 33 ++++++++++++--- 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index d40352f257..a7240df5c7 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -6,7 +6,7 @@ use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, IndexMap}; use futures::{FutureExt, future::Shared}; -use gpui::{BackgroundExecutor, Global, ReadGlobal, Task}; +use gpui::{BackgroundExecutor, Global, Task}; use indoc::indoc; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; @@ -15,7 +15,7 @@ use sqlez::{ connection::Connection, statement::Statement, }; -use std::{path::PathBuf, sync::Arc}; +use std::sync::Arc; use ui::{App, SharedString}; pub type DbMessage = crate::Message; @@ -221,6 +221,10 @@ pub(crate) struct ThreadsDatabase { connection: Arc>, } +struct GlobalThreadsDatabase(Shared, Arc>>>); + +impl Global for GlobalThreadsDatabase {} + impl ThreadsDatabase { fn connection(&self) -> Arc> { self.connection.clone() @@ -231,8 +235,11 @@ impl ThreadsDatabase { impl ThreadsDatabase { pub fn connect(cx: &mut App) -> Shared, Arc>>> { + if cx.has_global::() { + return cx.global::().0.clone(); + } let executor = cx.background_executor().clone(); - executor + let task = executor .spawn({ let executor = executor.clone(); async move { @@ -242,7 +249,10 @@ impl ThreadsDatabase { } } }) - .shared() + .shared(); + + cx.set_global(GlobalThreadsDatabase(task.clone())); + task } pub fn new(executor: BackgroundExecutor) -> Result { diff --git a/crates/agent_ui/src/acp.rs b/crates/agent_ui/src/acp.rs index c3fe90760e..efdeee9efd 100644 --- a/crates/agent_ui/src/acp.rs +++ b/crates/agent_ui/src/acp.rs @@ -8,5 +8,5 @@ mod thread_view; pub use model_selector::AcpModelSelector; pub use model_selector_popover::AcpModelSelectorPopover; -pub use thread_history::AcpThreadHistory; +pub use thread_history::{AcpThreadHistory, ThreadHistoryEvent}; pub use thread_view::AcpThreadView; diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index 4bf409a2af..b2fd1bfcb2 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -8,7 +8,7 @@ use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - App, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle, WeakEntity, Window, uniform_list, }; use project::Project; @@ -61,9 +61,14 @@ enum ListItemType { }, } +pub enum ThreadHistoryEvent { + Open(HistoryEntry), +} + +impl EventEmitter for AcpThreadHistory {} + impl AcpThreadHistory { pub(crate) fn new( - agent_panel: WeakEntity, project: &Entity, window: &mut Window, cx: &mut Context, @@ -117,7 +122,6 @@ impl AcpThreadHistory { let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); let mut this = Self { - agent_panel, history_store, scroll_handle, selected_index: 0, @@ -429,28 +433,29 @@ impl AcpThreadHistory { ) } - fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - self.confirm_entry(self.selected_index, window, cx); + fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + self.confirm_entry(self.selected_index, cx); } - fn confirm_entry(&mut self, ix: usize, window: &mut Window, cx: &mut Context) { + fn confirm_entry(&mut self, ix: usize, cx: &mut Context) { let Some(entry) = self.get_match(ix) else { return; }; - let task_result = match entry { - HistoryEntry::Thread(thread) => { - self.agent_panel.update(cx, move |agent_panel, cx| todo!()) - } - HistoryEntry::Context(context) => { - self.agent_panel.update(cx, move |agent_panel, cx| { - agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) - }) - } - }; + cx.emit(ThreadHistoryEvent::Open(entry.clone())); + // let task_result = match entry { + // HistoryEntry::Thread(thread) => { + // self.agent_panel.update(cx, move |agent_panel, cx| todo!()) + // } + // HistoryEntry::Context(context) => { + // self.agent_panel.update(cx, move |agent_panel, cx| { + // agent_panel.open_saved_prompt_editor(context.path.clone(), window, cx) + // }) + // } + // }; - if let Some(task) = task_result.log_err() { - task.detach_and_log_err(cx); - }; + // if let Some(task) = task_result.log_err() { + // task.detach_and_log_err(cx); + // }; cx.notify(); } @@ -606,9 +611,9 @@ impl AcpThreadHistory { } else { None }) - .on_click(cx.listener(move |this, _, window, cx| { - this.confirm_entry(list_entry_ix, window, cx) - })), + .on_click( + cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)), + ), ) .into_any_element() } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index af97f944e9..13a6fd1a21 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,6 +1,7 @@ use acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, - LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, UserMessageId, + AcpThread, AcpThreadEvent, AcpThreadMetadata, AgentThreadEntry, AssistantMessage, + AssistantMessageChunk, LoadError, MentionUri, ThreadStatus, ToolCall, ToolCallContent, + ToolCallStatus, UserMessageId, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -136,6 +137,7 @@ impl AcpThreadView { project: Entity, thread_store: Entity, text_thread_store: Entity, + restore_thread: Option, window: &mut Window, cx: &mut Context, ) -> Self { diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index f54ea9b31a..7030df1596 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -4,13 +4,13 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Duration; +use acp_thread::AcpThreadMetadata; use agent_servers::AgentServer; -use agent2::NativeAgentServer; +use agent2::history_store::HistoryEntry; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; -use crate::NewExternalAgentThread; -use crate::acp::AcpThreadHistory; +use crate::acp::{AcpThreadHistory, ThreadHistoryEvent}; use crate::agent_diff::AgentDiffThread; use crate::{ AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode, @@ -31,6 +31,7 @@ use crate::{ thread_history::{HistoryEntryElement, ThreadHistory}, ui::{AgentOnboardingModal, EndTrialUpsell}, }; +use crate::{ExternalAgent, NewExternalAgentThread}; use agent::{ Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, context_store::ContextStore, @@ -121,7 +122,7 @@ pub fn init(cx: &mut App) { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); panel.update(cx, |panel, cx| { - panel.new_external_thread(action.agent, window, cx) + panel.new_external_thread(action.agent, None, window, cx) }); } }) @@ -746,8 +747,26 @@ impl AgentPanel { ) }); - let acp_history = - cx.new(|cx| AcpThreadHistory::new(weak_self.clone(), &project, window, cx)); + let acp_history = cx.new(|cx| AcpThreadHistory::new(&project, window, cx)); + cx.subscribe_in( + &acp_history, + window, + |this, _, event, window, cx| match event { + ThreadHistoryEvent::Open(HistoryEntry::Thread(thread)) => { + let agent_choice = match thread.agent.0.as_ref() { + "Claude Code" => Some(ExternalAgent::ClaudeCode), + "Gemini" => Some(ExternalAgent::Gemini), + "Native Agent" => Some(ExternalAgent::NativeAgent), + _ => None, + }; + this.new_external_thread(agent_choice, Some(thread.clone()), window, cx); + } + ThreadHistoryEvent::Open(HistoryEntry::Context(thread)) => { + todo!() + } + }, + ) + .detach(); Self { active_view, @@ -962,6 +981,7 @@ impl AgentPanel { fn new_external_thread( &mut self, agent_choice: Option, + restore_thread: Option, window: &mut Window, cx: &mut Context, ) { @@ -1019,6 +1039,7 @@ impl AgentPanel { project, thread_store.clone(), text_thread_store.clone(), + restore_thread, window, cx, ) From fa6c0a1a49c6557ec29ebf87c8af78df474c6215 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 15 Aug 2025 23:43:16 -0600 Subject: [PATCH 08/25] More progress --- crates/acp_thread/src/connection.rs | 10 ++ crates/agent2/src/agent.rs | 115 +++++++++++++++++++++- crates/agent_ui/src/acp/thread_history.rs | 20 ++-- crates/agent_ui/src/acp/thread_view.rs | 35 +++++-- 4 files changed, 158 insertions(+), 22 deletions(-) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index dee18378ef..8446ebb074 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -31,6 +31,16 @@ pub trait AgentConnection { return None; } + fn load_thread( + self: Rc, + _project: Entity, + _cwd: &Path, + _session_id: acp::SessionId, + _cx: &mut App, + ) -> Task>> { + Task::ready(Err(anyhow::anyhow!("load thread not implemented"))) + } + fn auth_methods(&self) -> &[acp::AuthMethod]; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index be8428afc5..89317b9210 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,4 +1,3 @@ -use crate::ThreadsDatabase; use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::{ AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, @@ -6,7 +5,8 @@ use crate::{ MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use acp_thread::{AcpThreadMetadata, AgentModelSelector}; +use crate::{DbThread, ThreadsDatabase}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; @@ -18,7 +18,7 @@ use futures::{SinkExt, StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; -use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel}; use project::{Project, ProjectItem, ProjectPath, Worktree}; use prompt_store::{ ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, @@ -759,7 +759,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } fn list_threads(&self, cx: &mut App) -> Option>> { - dbg!("listing!"); let (mut tx, rx) = futures::channel::mpsc::unbounded(); let database = self.0.update(cx, |this, _| { this.history_listeners.push(tx.clone()); @@ -790,6 +789,114 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Some(rx) } + fn load_thread( + self: Rc, + project: Entity, + cwd: &Path, + session_id: acp::SessionId, + cx: &mut App, + ) -> Task>> { + let database = self.0.update(cx, |this, _| this.thread_database.clone()); + cx.spawn(async move |cx| { + let database = database.await.map_err(|e| anyhow!(e))?; + let db_thread = database + .load_thread(session_id.clone()) + .await? + .context("no such thread found")?; + + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new( + db_thread.title, + self.clone(), + project.clone(), + session_id.clone(), + cx, + ) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + let agent = self.0.clone(); + + // Create Thread + let thread = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + let configured_model = LanguageModelRegistry::global(cx) + .update(cx, |registry, cx| { + db_thread + .model + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.clone().into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + }) + .context("no default model configured")?; + + let model = agent + .models + .model_from_id(&LanguageModels::model_id(&configured_model.model)) + .context("no model by id")?; + + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + model, + cx, + ); + // todo!() factor this out + thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(CreateDirectoryTool::new(project.clone())); + thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); + thread.add_tool(DiagnosticsTool::new(project.clone())); + thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(GrepTool::new(project.clone())); + thread.add_tool(ListDirectoryTool::new(project.clone())); + thread.add_tool(MovePathTool::new(project.clone())); + thread.add_tool(NowTool); + thread.add_tool(OpenTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(TerminalTool::new(project.clone(), cx)); + thread.add_tool(ThinkingTool); + thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + thread + }); + + Ok(thread) + }, + )??; + + // Store the session + agent.update(cx, |agent, cx| { + agent.sessions.insert( + session_id, + Session { + thread, + acp_thread: acp_thread.downgrade(), + _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + }, + ); + })?; + + // we need to actually deserialize the DbThread. + todo!() + + Ok(acp_thread) + }) + } + fn model_selector(&self) -> Option> { Some(Rc::new(self.clone()) as Rc) } diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index b2fd1bfcb2..f4750281f9 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -21,7 +21,6 @@ use ui::{ use util::ResultExt; pub struct AcpThreadHistory { - agent_panel: WeakEntity, history_store: Entity, scroll_handle: UniformListScrollHandle, selected_index: usize, @@ -473,16 +472,17 @@ impl AcpThreadHistory { let Some(entry) = self.get_match(ix) else { return; }; - let task_result = match entry { - HistoryEntry::Thread(thread) => todo!(), - HistoryEntry::Context(context) => self - .agent_panel - .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), - }; + todo!(); + // let task_result = match entry { + // HistoryEntry::Thread(thread) => todo!(), + // HistoryEntry::Context(context) => self + // .agent_panel + // .update(cx, |this, cx| this.delete_context(context.path.clone(), cx)), + // }; - if let Some(task) = task_result.log_err() { - task.detach_and_log_err(cx); - }; + // if let Some(task) = task_result.log_err() { + // task.detach_and_log_err(cx); + // }; cx.notify(); } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 13a6fd1a21..52fbb9002d 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -169,7 +169,14 @@ impl AcpThreadView { project: project.clone(), thread_store, text_thread_store, - thread_state: Self::initial_state(agent, workspace, project, window, cx), + thread_state: Self::initial_state( + agent, + restore_thread, + workspace, + project, + window, + cx, + ), message_editor, model_selector: None, notifications: Vec::new(), @@ -193,6 +200,7 @@ impl AcpThreadView { fn initial_state( agent: Rc, + restore_thread: Option, workspace: WeakEntity, project: Entity, window: &mut Window, @@ -232,19 +240,27 @@ impl AcpThreadView { // .detach(); // }) // .ok(); - - let Some(result) = cx - .update(|_, cx| { + // + let task = cx.update(|_, cx| { + if let Some(restore_thread) = restore_thread { + connection.clone().load_thread( + project.clone(), + &root_dir, + restore_thread.id, + cx, + ) + } else { connection .clone() .new_thread(project.clone(), &root_dir, cx) - }) - .log_err() - else { + } + }); + + let Ok(task) = task else { return; }; - let result = match result.await { + let result = match task.await { Err(e) => { let mut cx = cx.clone(); if e.is::() { @@ -618,6 +634,7 @@ impl AcpThreadView { } else { this.thread_state = Self::initial_state( agent, + None, // todo!() this.workspace.clone(), project.clone(), window, @@ -3733,6 +3750,7 @@ pub(crate) mod tests { project, thread_store.clone(), text_thread_store.clone(), + None, window, cx, ) @@ -3884,6 +3902,7 @@ pub(crate) mod tests { project.clone(), thread_store.clone(), text_thread_store.clone(), + None, window, cx, ) From fae5900749c9347021d4d226ae59fa7b6d7b5a86 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Sun, 17 Aug 2025 13:24:27 -0600 Subject: [PATCH 09/25] factor otu --- crates/agent2/src/agent.rs | 59 ++++++++++++++++--------------------- crates/agent2/src/thread.rs | 4 +-- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 89317b9210..65f6e38c56 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -563,6 +563,30 @@ impl NativeAgentConnection { }) }) } + + fn register_tools( + thread: &mut Thread, + project: Entity, + action_log: Entity, + cx: &mut Context, + ) { + thread.add_tool(CopyPathTool::new(project.clone())); + thread.add_tool(CreateDirectoryTool::new(project.clone())); + thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); + thread.add_tool(DiagnosticsTool::new(project.clone())); + thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); + thread.add_tool(FindPathTool::new(project.clone())); + thread.add_tool(GrepTool::new(project.clone())); + thread.add_tool(ListDirectoryTool::new(project.clone())); + thread.add_tool(MovePathTool::new(project.clone())); + thread.add_tool(NowTool); + thread.add_tool(OpenTool::new(project.clone())); + thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(TerminalTool::new(project.clone(), cx)); + thread.add_tool(ThinkingTool); + thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + } } impl AgentModelSelector for NativeAgentConnection { @@ -709,22 +733,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { default_model, cx, ); - thread.add_tool(CopyPathTool::new(project.clone())); - thread.add_tool(CreateDirectoryTool::new(project.clone())); - thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); - thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); - thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(GrepTool::new(project.clone())); - thread.add_tool(ListDirectoryTool::new(project.clone())); - thread.add_tool(MovePathTool::new(project.clone())); - thread.add_tool(NowTool); - thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); - thread.add_tool(TerminalTool::new(project.clone(), cx)); - thread.add_tool(ThinkingTool); - thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + Self::register_tools(&mut thread, project, action_log, cx); thread }); @@ -852,23 +861,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { model, cx, ); - // todo!() factor this out - thread.add_tool(CopyPathTool::new(project.clone())); - thread.add_tool(CreateDirectoryTool::new(project.clone())); - thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); - thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); - thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); - thread.add_tool(FindPathTool::new(project.clone())); - thread.add_tool(GrepTool::new(project.clone())); - thread.add_tool(ListDirectoryTool::new(project.clone())); - thread.add_tool(MovePathTool::new(project.clone())); - thread.add_tool(NowTool); - thread.add_tool(OpenTool::new(project.clone())); - thread.add_tool(ReadFileTool::new(project.clone(), action_log)); - thread.add_tool(TerminalTool::new(project.clone(), cx)); - thread.add_tool(ThinkingTool); - thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. + Self::register_tools(&mut thread, project, action_log, cx); thread }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 784477d677..33f0208ab7 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -441,7 +441,7 @@ impl Thread { cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - Self { + let this = Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, @@ -455,7 +455,7 @@ impl Thread { model, project, action_log, - } + }; } pub fn project(&self) -> &Entity { From a231fd3ee5c87a9110331c5d5226025ec7c15848 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 10:13:02 +0200 Subject: [PATCH 10/25] Take a weak thread in EditFileTool to avoid cycle --- crates/agent2/src/agent.rs | 77 ++++++++++----------- crates/agent2/src/thread.rs | 4 +- crates/agent2/src/tools/edit_file_tool.rs | 83 ++++++++++++++--------- 3 files changed, 89 insertions(+), 75 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 65f6e38c56..074dffe1dc 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -574,7 +574,7 @@ impl NativeAgentConnection { thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.entity())); + thread.add_tool(EditFileTool::new(cx.weak_entity())); thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone())); @@ -801,7 +801,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn load_thread( self: Rc, project: Entity, - cwd: &Path, + _cwd: &Path, session_id: acp::SessionId, cx: &mut App, ) -> Task>> { @@ -828,46 +828,43 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let agent = self.0.clone(); // Create Thread - let thread = agent.update( - cx, - |agent, cx: &mut gpui::Context| -> Result<_> { - let configured_model = LanguageModelRegistry::global(cx) - .update(cx, |registry, cx| { - db_thread - .model - .and_then(|model| { - let model = SelectedModel { - provider: model.provider.clone().into(), - model: model.model.clone().into(), - }; - registry.select_model(&model, cx) - }) - .or_else(|| registry.default_model()) - }) - .context("no default model configured")?; + let thread = agent.update(cx, |agent, cx| { + let configured_model = LanguageModelRegistry::global(cx) + .update(cx, |registry, cx| { + db_thread + .model + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.clone().into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + }) + .context("no default model configured")?; - let model = agent - .models - .model_from_id(&LanguageModels::model_id(&configured_model.model)) - .context("no model by id")?; + let model = agent + .models + .model_from_id(&LanguageModels::model_id(&configured_model.model)) + .context("no model by id")?; - let thread = cx.new(|cx| { - let mut thread = Thread::new( - project.clone(), - agent.project_context.clone(), - agent.context_server_registry.clone(), - action_log.clone(), - agent.templates.clone(), - model, - cx, - ); - Self::register_tools(&mut thread, project, action_log, cx); - thread - }); + let thread = cx.new(|cx| { + let mut thread = Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + action_log.clone(), + agent.templates.clone(), + model, + cx, + ); + Self::register_tools(&mut thread, project, action_log, cx); + thread + }); - Ok(thread) - }, - )??; + anyhow::Ok(thread) + })??; // Store the session agent.update(cx, |agent, cx| { @@ -884,7 +881,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; // we need to actually deserialize the DbThread. - todo!() + // todo!() Ok(acp_thread) }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 33f0208ab7..784477d677 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -441,7 +441,7 @@ impl Thread { cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - let this = Self { + Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, @@ -455,7 +455,7 @@ impl Thread { model, project, action_log, - }; + } } pub fn project(&self) -> &Entity { diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index c77b9f6a69..6462308918 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; -use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; use language::ToPoint; use language::language_settings::{self, FormatOnSave}; @@ -122,11 +122,11 @@ impl From for LanguageModelToolResultContent { } pub struct EditFileTool { - thread: Entity, + thread: WeakEntity, } impl EditFileTool { - pub fn new(thread: Entity) -> Self { + pub fn new(thread: WeakEntity) -> Self { Self { thread } } @@ -167,8 +167,11 @@ impl EditFileTool { // Check if path is inside the global config directory // First check if it's already inside project - if not, try to canonicalize - let thread = self.thread.read(cx); - let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + let Ok(project_path) = self.thread.read_with(cx, |thread, cx| { + thread.project().read(cx).find_project_path(&input.path, cx) + }) else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; // If the path is inside the project, and it's not one of the above edge cases, // then no confirmation is necessary. Otherwise, confirmation is necessary. @@ -221,7 +224,12 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { - let project = self.thread.read(cx).project().clone(); + let Ok(project) = self + .thread + .read_with(cx, |thread, _cx| thread.project().clone()) + else { + return Task::ready(Err(anyhow!("thread was dropped"))); + }; let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))), @@ -237,17 +245,15 @@ impl AgentTool for EditFileTool { }); } - let request = self.thread.update(cx, |thread, cx| { - thread.build_completion_request(CompletionIntent::ToolResults, cx) - }); - let thread = self.thread.read(cx); - let model = thread.model().clone(); - let action_log = thread.action_log().clone(); - let authorize = self.authorize(&input, &event_stream, cx); cx.spawn(async move |cx: &mut AsyncApp| { authorize.await?; + let (request, model, action_log) = self.thread.update(cx, |thread, cx| { + let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); + (request, thread.model().clone(), thread.action_log().clone()) + })?; + let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( model, @@ -531,7 +537,11 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }) .await; assert_eq!( @@ -744,10 +754,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -800,7 +811,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the unformatted content @@ -881,10 +896,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -932,10 +948,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool { - thread: thread.clone(), - }) - .run(input, ToolCallEventStream::test().0, cx) + Arc::new(EditFileTool::new(thread.downgrade())).run( + input, + ToolCallEventStream::test().0, + cx, + ) }); // Stream the content with trailing whitespace @@ -983,7 +1000,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1114,7 +1131,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1224,7 +1241,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test files in different worktrees let test_cases = vec![ @@ -1305,7 +1322,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test edge cases let test_cases = vec![ @@ -1389,7 +1406,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); // Test different EditFileMode values let modes = vec![ @@ -1470,7 +1487,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool { thread }); + let tool = Arc::new(EditFileTool::new(thread.downgrade())); assert_eq!( tool.initial_title(Err(json!({ From 5259c8692dde8366c2eda6ec250ebdff0e6e0601 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 13:45:55 +0200 Subject: [PATCH 11/25] WIP --- crates/agent2/src/agent.rs | 33 +-- crates/agent2/src/db.rs | 8 +- crates/agent2/src/tests/mod.rs | 24 +- crates/agent2/src/thread.rs | 257 ++++++++++++++---- .../src/tools/context_server_registry.rs | 10 + crates/agent2/src/tools/terminal_tool.rs | 4 +- 6 files changed, 250 insertions(+), 86 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index fa6201a483..2a17f23fd0 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,9 +1,9 @@ use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::{ - AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, - DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, - ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, + EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, + UserMessageContent, WebSearchTool, templates::Templates, }; use crate::{DbThread, ThreadsDatabase}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; @@ -461,10 +461,7 @@ impl NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, f: impl 'static - + FnOnce( - Entity, - &mut App, - ) -> Result>>, + + FnOnce(Entity, &mut App) -> Result>>, ) -> Task> { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { agent @@ -488,7 +485,10 @@ impl NativeAgentConnection { log::trace!("Received completion event: {:?}", event); match event { - AgentResponseEvent::Text(text) => { + ThreadEvent::UserMessage(message) => { + todo!() + } + ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -500,7 +500,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::Thinking(text) => { + ThreadEvent::AgentThinking(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -512,7 +512,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { tool_call, options, response, @@ -535,17 +535,17 @@ impl NativeAgentConnection { }) .detach(); } - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) })??; } - AgentResponseEvent::ToolCallUpdate(update) => { + ThreadEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { thread.update_tool_call(update, cx) })??; } - AgentResponseEvent::Stop(stop_reason) => { + ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); } @@ -786,7 +786,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .into_iter() .map(|thread| AcpThreadMetadata { agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id, + id: thread.id.into(), title: thread.title, updated_at: thread.updated_at, }) @@ -806,11 +806,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { + let thread_id = session_id.clone().into(); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(session_id.clone()) + .load_thread(thread_id) .await? .context("no such thread found")?; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index a7240df5c7..5332da276f 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,4 @@ -use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent}; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { - pub id: acp::SessionId, + pub id: ThreadId, #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, @@ -323,7 +323,7 @@ impl ThreadsDatabase { for (id, summary, updated_at) in rows { threads.push(DbThreadMetadata { - id: acp::SessionId(id), + id: ThreadId(id), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), }); @@ -333,7 +333,7 @@ impl ThreadsDatabase { }) } - pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + pub fn load_thread(&self, id: ThreadId) -> Task>> { let connection = self.connection.clone(); self.executor.spawn(async move { diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 48a16bf685..5cbddafded 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -329,7 +329,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { - if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + if let Ok(ThreadEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message let message = thread.last_message().unwrap(); @@ -710,7 +710,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { } async fn expect_tool_call( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCall { let event = events .next() @@ -718,7 +718,7 @@ async fn expect_tool_call( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCall(tool_call) => return tool_call, + ThreadEvent::ToolCall(tool_call) => return tool_call, event => { panic!("Unexpected event {event:?}"); } @@ -726,7 +726,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -734,7 +734,7 @@ async fn expect_tool_call_update_fields( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { return update; } event => { @@ -744,7 +744,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -752,7 +752,7 @@ async fn next_tool_call_authorization( .await .expect("no tool call authorization event received") .unwrap(); - if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event { let permission_kinds = tool_call_authorization .options .iter() @@ -912,13 +912,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut echo_completed = false; while let Some(event) = events.next().await { match event.unwrap() { - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { assert_eq!(tool_call.title, expected_tools.remove(0)); if tool_call.title == "Echo" { echo_id = Some(tool_call.id); } } - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( acp::ToolCallUpdate { id, fields: @@ -946,7 +946,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { assert!( matches!( last_event, - Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) ), "unexpected event {last_event:?}" ); @@ -1386,11 +1386,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events(result_events: Vec>) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { - AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + ThreadEvent::Stop(stop_reason) => Some(stop_reason), _ => None, }) .collect() diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index bd6d78b2f0..87f4803daf 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,4 +1,4 @@ -use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent_client_protocol as acp; @@ -30,10 +30,12 @@ use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; + #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] -pub struct ThreadId(Arc); +pub struct ThreadId(pub(crate) Arc); impl ThreadId { pub fn new() -> Self { @@ -53,6 +55,18 @@ impl From<&str> for ThreadId { } } +impl From for ThreadId { + fn from(value: acp::SessionId) -> Self { + Self(value.0) + } +} + +impl From for acp::SessionId { + fn from(value: ThreadId) -> Self { + Self(value.0) + } +} + /// The ID of the user prompt that initiated a request. /// /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). @@ -313,9 +327,6 @@ impl AgentMessage { AgentMessageContent::RedactedThinking(_) => { markdown.push_str("\n") } - AgentMessageContent::Image(_) => { - markdown.push_str("\n"); - } AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -386,9 +397,6 @@ impl AgentMessage { AgentMessageContent::ToolUse(value) => { language_model::MessageContent::ToolUse(value.clone()) } - AgentMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } }; assistant_message.content.push(chunk); } @@ -432,14 +440,14 @@ pub enum AgentMessageContent { signature: Option, }, RedactedThinking(String), - Image(LanguageModelImage), ToolUse(LanguageModelToolUse), } #[derive(Debug)] -pub enum AgentResponseEvent { - Text(String), - Thinking(String), +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), @@ -504,6 +512,121 @@ impl Thread { } } + pub fn from_db( + id: ThreadId, + db_thread: DbThread, + project: Entity, + project_context: Rc>, + context_server_registry: Entity, + action_log: Entity, + templates: Arc, + model: Arc, + cx: &mut Context, + ) -> Self { + let profile_id = db_thread + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); + Self { + id, + prompt_id: PromptId::new(), + messages: db_thread.messages, + completion_mode: CompletionMode::Normal, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + context_server_registry, + profile_id, + project_context, + templates, + model, + project, + action_log, + } + } + + pub fn replay(&self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(&user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } + } + rx + } + + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; + + let title = tool.initial_title(tool_use.input.clone()); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + if let Some(output) = tool_result + .as_ref() + .and_then(|result| result.output.clone()) + { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } else { + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + content: Some(vec![TOOL_CANCELED_MESSAGE.into()]), + status: Some(acp::ToolCallStatus::Failed), + ..Default::default() + }, + ); + } + } + pub fn project(&self) -> &Entity { &self.project } @@ -574,7 +697,7 @@ impl Thread { pub fn resume( &mut self, cx: &mut Context, - ) -> Result>> { + ) -> Result>> { anyhow::ensure!( self.tool_use_limit_reached, "can only resume after tool use limit is reached" @@ -595,7 +718,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> mpsc::UnboundedReceiver> where T: Into, { @@ -613,15 +736,12 @@ impl Thread { self.run_turn(cx) } - fn run_turn( - &mut self, - cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { + fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { self.cancel(); let model = self.model.clone(); - let (events_tx, events_rx) = mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; self.running_turn = Some(RunningTurn { @@ -755,7 +875,7 @@ impl Thread { fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { log::trace!("Handling streamed completion event: {:?}", event); @@ -797,7 +917,7 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_text(&new_text); @@ -818,7 +938,7 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_thinking(&new_text); @@ -850,7 +970,7 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { cx.notify(); @@ -989,9 +1109,7 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text( - "Tool canceled by user".into(), - ), + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), output: None, }, ); @@ -1143,7 +1261,7 @@ struct RunningTurn { _task: Task<()>, /// The current event stream for the running turn. Used to report a final /// cancellation event if we cancel the turn. - event_stream: AgentResponseEventStream, + event_stream: ThreadEventStream, } impl RunningTurn { @@ -1196,6 +1314,17 @@ where cx: &mut App, ) -> Task>; + /// Emits events for a previous execution of the tool. + fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } @@ -1223,6 +1352,13 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; } impl AnyAgentTool for Erased> @@ -1274,21 +1410,39 @@ where }) }) } + + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) + } } #[derive(Clone)] -struct AgentResponseEventStream(mpsc::UnboundedSender>); +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); + } -impl AgentResponseEventStream { fn send_text(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) .ok(); } fn send_thinking(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) .ok(); } @@ -1300,7 +1454,7 @@ impl AgentResponseEventStream { input: serde_json::Value, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, title.to_string(), kind, @@ -1333,7 +1487,7 @@ impl AgentResponseEventStream { fields: acp::ToolCallUpdateFields, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.to_string().into()), fields, @@ -1347,17 +1501,17 @@ impl AgentResponseEventStream { match reason { StopReason::EndTurn => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn))) .ok(); } StopReason::MaxTokens => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens))) .ok(); } StopReason::Refusal => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal))) .ok(); } StopReason::ToolUse => {} @@ -1366,7 +1520,7 @@ impl AgentResponseEventStream { fn send_canceled(&self) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) .ok(); } @@ -1378,24 +1532,23 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = - ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, ) -> Self { Self { @@ -1413,7 +1566,7 @@ impl ToolCallEventStream { pub fn update_diff(&self, diff: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateDiff { id: acp::ToolCallId(self.tool_use_id.to_string().into()), diff, @@ -1426,7 +1579,7 @@ impl ToolCallEventStream { pub fn update_terminal(&self, terminal: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateTerminal { id: acp::ToolCallId(self.tool_use_id.to_string().into()), terminal, @@ -1444,7 +1597,7 @@ impl ToolCallEventStream { let (response_tx, response_rx) = oneshot::channel(); self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( ToolCallAuthorization { tool_call: acp::ToolCallUpdate { id: acp::ToolCallId(self.tool_use_id.to_string().into()), @@ -1494,13 +1647,13 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { auth } else { panic!("Expected ToolCallAuthorization but got: {:?}", event); @@ -1509,9 +1662,9 @@ impl ToolCallEventStreamReceiver { pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdate::UpdateTerminal(update), - ))) = event + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event { update.terminal } else { @@ -1522,7 +1675,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs index db39e9278c..3b5225317f 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool { }) }) } + + fn replay( + &self, + _input: serde_json::Value, + _output: serde_json::Value, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } } diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs index ecb855ac34..6984475d18 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -319,7 +319,7 @@ mod tests { use theme::ThemeSettings; use util::test::TempTree; - use crate::AgentResponseEvent; + use crate::ThreadEvent; use super::*; @@ -396,7 +396,7 @@ mod tests { }); cx.run_until_parked(); let event = stream_rx.try_next(); - if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { + if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event { auth.response.send(auth.options[0].id.clone()).unwrap(); } From 3a0e55d9b6fb9e9d48e759e3d3337fde5b6b84f2 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 14:16:19 +0200 Subject: [PATCH 12/25] Stream deserialized thread to `AcpThread` --- crates/agent2/src/agent.rs | 40 ++++++++++++++++++++++++--------- crates/agent2/src/thread.rs | 44 +++++++++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 2a17f23fd0..5b27f0a048 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,7 +5,7 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadsDatabase}; +use crate::{DbThread, ThreadId, ThreadsDatabase}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; @@ -473,10 +473,18 @@ impl NativeAgentConnection { }; log::debug!("Found session for: {}", session_id); - let mut response_stream = match f(thread, cx) { + let response_stream = match f(thread, cx) { Ok(stream) => stream, Err(err) => return Task::ready(Err(err)), }; + Self::handle_thread_events(response_stream, acp_thread, cx) + } + + fn handle_thread_events( + mut response_stream: mpsc::UnboundedReceiver>, + acp_thread: WeakEntity, + cx: &mut App, + ) -> Task> { cx.spawn(async move |cx| { // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { @@ -486,7 +494,15 @@ impl NativeAgentConnection { match event { ThreadEvent::UserMessage(message) => { - todo!() + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; } ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { @@ -806,19 +822,19 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { - let thread_id = session_id.clone().into(); + let thread_id = ThreadId::from(session_id.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(thread_id) + .load_thread(thread_id.clone()) .await? .context("no such thread found")?; let acp_thread = cx.update(|cx| { cx.new(|cx| { acp_thread::AcpThread::new( - db_thread.title, + db_thread.title.clone(), self.clone(), project.clone(), session_id.clone(), @@ -835,6 +851,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .update(cx, |registry, cx| { db_thread .model + .as_ref() .and_then(|model| { let model = SelectedModel { provider: model.provider.clone().into(), @@ -852,7 +869,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .context("no model by id")?; let thread = cx.new(|cx| { - let mut thread = Thread::new( + let mut thread = Thread::from_db( + thread_id, + db_thread, project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), @@ -873,7 +892,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { agent.sessions.insert( session_id, Session { - thread, + thread: thread.clone(), acp_thread: acp_thread.downgrade(), _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); @@ -882,8 +901,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ); })?; - // we need to actually deserialize the DbThread. - // todo!() + let events = thread.update(cx, |thread, cx| thread.replay(cx))?; + cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))? + .await?; Ok(acp_thread) }) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 87f4803daf..6d688675ba 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -12,7 +12,7 @@ use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, SharedString, Task}; +use gpui::{App, AppContext, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, @@ -545,7 +545,10 @@ impl Thread { } } - pub fn replay(&self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + pub fn replay( + &mut self, + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { let (tx, rx) = mpsc::unbounded(); let stream = ThreadEventStream(tx); for message in &self.messages { @@ -615,16 +618,15 @@ impl Thread { ); tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) .log_err(); - } else { - stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields { - content: Some(vec![TOOL_CANCELED_MESSAGE.into()]), - status: Some(acp::ToolCallStatus::Failed), - ..Default::default() - }, - ); } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + ); } pub fn project(&self) -> &Entity { @@ -1744,6 +1746,26 @@ impl From for UserMessageContent { } } +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + annotations: None, + uri: None, + }), + UserMessageContent::Mention { uri, content } => { + todo!() + } + } + } +} + fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { LanguageModelImage { source: image_content.data.into(), From 199256e43ec88edb30e9fc9269a313f6f2ff0584 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 14:21:01 +0200 Subject: [PATCH 13/25] Checkpoint --- crates/agent2/src/thread.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 6d688675ba..034b26b714 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -607,10 +607,10 @@ impl Thread { let kind = tool.kind(); stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); - if let Some(output) = tool_result + let output = tool_result .as_ref() - .and_then(|result| result.output.clone()) - { + .and_then(|result| result.output.clone()); + if let Some(output) = output.clone() { let tool_event_stream = ToolCallEventStream::new( tool_use.id.clone(), stream.clone(), @@ -624,6 +624,7 @@ impl Thread { &tool_use.id, acp::ToolCallUpdateFields { status: Some(acp::ToolCallStatus::Completed), + raw_output: output, ..Default::default() }, ); From e6e23d04f8e40afc0b1dd122713b0c1674ae4d70 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 14:56:10 +0200 Subject: [PATCH 14/25] Checkpoint --- crates/acp_thread/src/acp_thread.rs | 12 +++- crates/acp_thread/src/diff.rs | 13 ++-- crates/agent2/src/agent.rs | 3 +- crates/agent2/src/tools/edit_file_tool.rs | 82 +++++++++++++++------- crates/agent2/src/tools/web_search_tool.rs | 67 +++++++++++------- 5 files changed, 113 insertions(+), 64 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 0fb0d9e779..a0e62c29e3 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -539,9 +539,15 @@ impl ToolCallContent { acp::ToolCallContent::Content { content } => { Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) } - acp::ToolCallContent::Diff { diff } => { - Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) - } + acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| { + Diff::finalized( + diff.path, + diff.old_text, + diff.new_text, + language_registry, + cx, + ) + })), } } diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs index a2c2d6c322..a67e37bcb8 100644 --- a/crates/acp_thread/src/diff.rs +++ b/crates/acp_thread/src/diff.rs @@ -1,4 +1,3 @@ -use agent_client_protocol as acp; use anyhow::Result; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{MultiBuffer, PathKey}; @@ -21,17 +20,13 @@ pub enum Diff { } impl Diff { - pub fn from_acp( - diff: acp::Diff, + pub fn finalized( + path: PathBuf, + old_text: Option, + new_text: String, language_registry: Arc, cx: &mut Context, ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5b27f0a048..398054739e 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -587,11 +587,12 @@ impl NativeAgentConnection { action_log: Entity, cx: &mut Context, ) { + let language_registry = project.read(cx).languages().clone(); thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DiagnosticsTool::new(project.clone())); - thread.add_tool(EditFileTool::new(cx.weak_entity())); + thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry)); thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone())); diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 62774ac2b1..01fa77e22d 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -7,8 +7,8 @@ use cloud_llm_client::CompletionIntent; use collections::HashSet; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use indoc::formatdoc; -use language::ToPoint; use language::language_settings::{self, FormatOnSave}; +use language::{LanguageRegistry, ToPoint}; use language_model::LanguageModelToolResultContent; use paths; use project::lsp_store::{FormatTrigger, LspFormatTarget}; @@ -98,11 +98,13 @@ pub enum EditFileMode { #[derive(Debug, Serialize, Deserialize)] pub struct EditFileToolOutput { + #[serde(alias = "original_path")] input_path: PathBuf, - project_path: PathBuf, new_text: String, old_text: Arc, + #[serde(default)] diff: String, + #[serde(alias = "raw_output")] edit_agent_output: EditAgentOutput, } @@ -123,11 +125,15 @@ impl From for LanguageModelToolResultContent { pub struct EditFileTool { thread: WeakEntity, + language_registry: Arc, } impl EditFileTool { - pub fn new(thread: WeakEntity) -> Self { - Self { thread } + pub fn new(thread: WeakEntity, language_registry: Arc) -> Self { + Self { + thread, + language_registry, + } } fn authorize( @@ -419,7 +425,6 @@ impl AgentTool for EditFileTool { Ok(EditFileToolOutput { input_path: input.path, - project_path: project_path.path.to_path_buf(), new_text: new_text.clone(), old_text, diff: unified_diff, @@ -427,6 +432,26 @@ impl AgentTool for EditFileTool { }) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + dbg!(&output); + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + output.input_path, + Some(output.old_text.to_string()), + output.new_text, + self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } } /// Validate that the file path is valid, meaning: @@ -515,6 +540,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -537,7 +563,7 @@ mod tests { path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -754,11 +780,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( - input, - ToolCallEventStream::test().0, - cx, - ) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) + .run(input, ToolCallEventStream::test().0, cx) }); // Stream the unformatted content @@ -811,7 +837,7 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -857,6 +883,7 @@ mod tests { .unwrap(); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -896,11 +923,11 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( - input, - ToolCallEventStream::test().0, - cx, - ) + Arc::new(EditFileTool::new( + thread.downgrade(), + language_registry.clone(), + )) + .run(input, ToolCallEventStream::test().0, cx) }); // Stream the content with trailing whitespace @@ -948,7 +975,7 @@ mod tests { path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, }; - Arc::new(EditFileTool::new(thread.downgrade())).run( + Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run( input, ToolCallEventStream::test().0, cx, @@ -985,6 +1012,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1000,7 +1028,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; // Test 1: Path with .zed component should require confirmation @@ -1122,6 +1150,7 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); @@ -1137,7 +1166,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project let test_cases = vec![ @@ -1231,7 +1260,7 @@ mod tests { cx, ) .await; - + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1247,7 +1276,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees let test_cases = vec![ @@ -1313,6 +1342,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1328,7 +1358,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases let test_cases = vec![ @@ -1397,6 +1427,7 @@ mod tests { ) .await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1412,7 +1443,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values let modes = vec![ @@ -1478,6 +1509,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -1493,7 +1525,7 @@ mod tests { cx, ) }); - let tool = Arc::new(EditFileTool::new(thread.downgrade())); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( tool.initial_title(Err(json!({ diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index c1c0970742..d71a128bfe 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool { } }; - let result_text = if response.results.len() == 1 { - "1 result".to_string() - } else { - format!("{} results", response.results.len()) - }; - event_stream.update_fields(acp::ToolCallUpdateFields { - title: Some(format!("Searched the web: {result_text}")), - content: Some( - response - .results - .iter() - .map(|result| acp::ToolCallContent::Content { - content: acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: result.title.clone(), - uri: result.url.clone(), - title: Some(result.title.clone()), - description: Some(result.text.clone()), - mime_type: None, - annotations: None, - size: None, - }), - }) - .collect(), - ), - ..Default::default() - }); + emit_update(&response, &event_stream); Ok(WebSearchToolOutput(response)) }) } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + emit_update(&output.0, &event_stream); + Ok(()) + } +} + +fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) { + let result_text = if response.results.len() == 1 { + "1 result".to_string() + } else { + format!("{} results", response.results.len()) + }; + event_stream.update_fields(acp::ToolCallUpdateFields { + title: Some(format!("Searched the web: {result_text}")), + content: Some( + response + .results + .iter() + .map(|result| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: result.title.clone(), + uri: result.url.clone(), + title: Some(result.title.clone()), + description: Some(result.text.clone()), + mime_type: None, + annotations: None, + size: None, + }), + }) + .collect(), + ), + ..Default::default() + }); } From 205b1371aad3956073259db45bc12fc70e553671 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 15:08:12 +0200 Subject: [PATCH 15/25] Synchronize initial entries --- crates/agent_ui/src/acp/thread_view.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 4dd0a3b7f6..38036ad3c4 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -308,8 +308,13 @@ impl AcpThreadView { let action_log_subscription = cx.observe(&action_log, |_, _, cx| cx.notify()); - this.list_state - .splice(0..0, thread.read(cx).entries().len()); + let count = thread.read(cx).entries().len(); + this.list_state.splice(0..0, count); + this.entry_view_state.update(cx, |view_state, cx| { + for ix in 0..count { + view_state.sync_entry(ix, &thread, window, cx); + } + }); AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); From d83210d978b7bb1fe11925c0cc772d405f79e003 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 17:30:12 +0200 Subject: [PATCH 16/25] WIP Co-authored-by: Conrad Irwin --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 5 +- crates/agent2/Cargo.toml | 2 + crates/agent2/src/agent.rs | 151 ++++++++------ crates/agent2/src/agent2.rs | 6 + crates/agent2/src/tests/mod.rs | 5 +- crates/agent2/src/thread.rs | 242 ++++++++++++++++++---- crates/agent2/src/tools/edit_file_tool.rs | 11 +- 8 files changed, 319 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5479fddab3..4141111073 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,6 +211,7 @@ dependencies = [ "env_logger 0.11.8", "fs", "futures 0.3.31", + "git", "gpui", "gpui_tokio", "handlebars 4.5.0", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index fe6d3169bd..398222a831 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -27,7 +27,10 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; - fn list_threads(&self, _cx: &mut App) -> Option>> { + fn list_threads( + &self, + _cx: &mut App, + ) -> Option>>> { return None; } diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ed487e2e22..88da875930 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -27,6 +27,7 @@ collections.workspace = true context_server.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } html_to_markdown.workspace = true @@ -72,6 +73,7 @@ context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 398054739e..403a59e51b 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,7 +5,7 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadId, ThreadsDatabase}; +use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; @@ -44,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [ "GEMINI.md", ]; +const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500); + pub struct RulesLoadingError { pub message: SharedString, } @@ -54,7 +56,8 @@ struct Session { thread: Entity, /// The ACP thread that handles protocol communication acp_thread: WeakEntity, - _subscription: Subscription, + save_task: Task>, + _subscriptions: Vec, } pub struct LanguageModels { @@ -169,8 +172,9 @@ pub struct NativeAgent { models: LanguageModels, project: Entity, prompt_store: Option>, - thread_database: Shared, Arc>>>, - history_listeners: Vec>>, + thread_database: Arc, + history: watch::Sender>>, + load_history: Task>, fs: Arc, _subscriptions: Vec, } @@ -189,6 +193,11 @@ impl NativeAgent { .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? .await; + let thread_database = cx + .update(|cx| ThreadsDatabase::connect(cx))? + .await + .map_err(|e| anyhow!(e))?; + cx.new(|cx| { let mut subscriptions = vec![ cx.subscribe(&project, Self::handle_project_event), @@ -203,7 +212,7 @@ impl NativeAgent { let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); - Self { + let this = Self { sessions: HashMap::new(), project_context: Rc::new(RefCell::new(project_context)), project_context_needs_refresh: project_context_needs_refresh_tx, @@ -213,18 +222,85 @@ impl NativeAgent { context_server_registry: cx.new(|cx| { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), - thread_database: ThreadsDatabase::connect(cx), + thread_database, templates, models: LanguageModels::new(cx), project, prompt_store, fs, - history_listeners: Vec::new(), + history: watch::channel(None).0, + load_history: Task::ready(Ok(())), _subscriptions: subscriptions, - } + }; + this.reload_history(cx); + this }) } + pub fn insert_session( + &mut self, + thread: Entity, + acp_thread: Entity, + cx: &mut Context, + ) { + let id = thread.read(cx).id().clone(); + self.sessions.insert( + id, + Session { + thread: thread.clone(), + acp_thread: acp_thread.downgrade(), + save_task: Task::ready(()), + _subscriptions: vec![ + cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + cx.observe(&thread, |this, thread, cx| { + this.save_thread(thread.clone(), cx) + }), + ], + }, + ); + } + + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + let id = thread.read(cx).id().clone(); + let Some(session) = self.sessions.get_mut(&id) else { + return; + }; + + let thread = thread.downgrade(); + let thread_database = self.thread_database.clone(); + session.save_task = cx.spawn(async move |this, cx| { + cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; + thread_database.save_thread(id, db_thread).await?; + this.update(cx, |this, cx| this.reload_history(cx))?; + Ok(()) + }); + } + + fn reload_history(&mut self, cx: &mut Context) { + let thread_database = self.thread_database.clone(); + self.load_history = cx.spawn(async move |this, cx| { + let results = cx + .background_spawn(async move { + let results = thread_database.list_threads().await?; + Ok(results + .into_iter() + .map(|thread| AcpThreadMetadata { + agent: NATIVE_AGENT_SERVER_NAME.clone(), + id: thread.id.into(), + title: thread.title, + updated_at: thread.updated_at, + }) + .collect()) + }) + .await?; + this.update(cx, |this, cx| this.history.send(Some(results)))?; + anyhow::Ok(()) + }); + } + pub fn models(&self) -> &LanguageModels { &self.models } @@ -699,7 +775,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Starting thread creation in async context"); // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + let session_id = generate_session_id(); log::info!("Created session with ID: {}", session_id); // Create AcpThread @@ -743,6 +819,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let thread = cx.new(|cx| { let mut thread = Thread::new( + session_id.clone(), project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), @@ -761,16 +838,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread, - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); + agent.insert_session(thread, acp_thread.clone(), cx) })?; Ok(acp_thread) @@ -785,35 +853,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn list_threads(&self, cx: &mut App) -> Option>> { - let (mut tx, rx) = futures::channel::mpsc::unbounded(); - let database = self.0.update(cx, |this, _| { - this.history_listeners.push(tx.clone()); - this.thread_database.clone() - }); - cx.background_executor() - .spawn(async move { - dbg!("listing!"); - let database = database.await.map_err(|e| anyhow!(e))?; - let results = database.list_threads().await?; - - dbg!(&results); - tx.send( - results - .into_iter() - .map(|thread| AcpThreadMetadata { - agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id.into(), - title: thread.title, - updated_at: thread.updated_at, - }) - .collect(), - ) - .await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - Some(rx) + fn list_threads( + &self, + cx: &mut App, + ) -> Option>>> { + Some(self.0.read(cx).history.receiver()) } fn load_thread( @@ -890,16 +934,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread: thread.clone(), - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); + agent.insert_session(session_id, thread, acp_thread, cx) })?; let events = thread.update(cx, |thread, cx| thread.replay(cx))?; diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 1813fe1880..eee9810cef 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -15,3 +15,9 @@ pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; pub use tools::*; + +use agent_client_protocol as acp; + +pub fn generate_session_id() -> acp::SessionId { + acp::SessionId(uuid::Uuid::new_v4().to_string().into()) +} diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 5cbddafded..75a21a2baa 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -709,9 +709,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { ); } -async fn expect_tool_call( - events: &mut UnboundedReceiver>, -) -> acp::ToolCall { +async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall { let event = events .next() .await @@ -1501,6 +1499,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, project_context.clone(), context_server_registry, diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 034b26b714..ec820c7b5f 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,25 +1,35 @@ -use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates}; +use crate::{ + ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates, +}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; +use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; +use chrono::{DateTime, Utc}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; use collections::IndexMap; use fs::Fs; use futures::{ + FutureExt, channel::{mpsc, oneshot}, + future::Shared, stream::FuturesUnordered, }; +use git::repository::DiffType; use gpui::{App, AppContext, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage, +}; +use project::{ + Project, + git_store::{GitStore, RepositoryState}, }; -use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; @@ -32,41 +42,6 @@ use uuid::Uuid; const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(pub(crate) Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} - -impl From for ThreadId { - fn from(value: acp::SessionId) -> Self { - Self(value.0) - } -} - -impl From for acp::SessionId { - fn from(value: ThreadId) -> Self { - Self(value.0) - } -} - /// The ID of the user prompt that initiated a request. /// /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). @@ -461,9 +436,28 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +enum ThreadTitle { + None, + Pending(Task<()>), + Done(Result), +} + +impl ThreadTitle { + pub fn unwrap_or_default(&self) -> SharedString { + if let ThreadTitle::Done(Ok(title)) = self { + title.clone() + } else { + "New Thread".into() + } + } +} + pub struct Thread { - id: ThreadId, + id: acp::SessionId, prompt_id: PromptId, + updated_at: DateTime, + title: ThreadTitle, + summary: DetailedSummaryState, messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. @@ -473,6 +467,9 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, + request_token_usage: Vec, + cumulative_token_usage: TokenUsage, + initial_project_snapshot: Shared>>>, context_server_registry: Entity, profile_id: AgentProfileId, project_context: Rc>, @@ -484,6 +481,7 @@ pub struct Thread { impl Thread { pub fn new( + id: acp::SessionId, project: Entity, project_context: Rc>, context_server_registry: Entity, @@ -494,14 +492,25 @@ impl Thread { ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { - id: ThreadId::new(), + id, prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: ThreadTitle::None, + summary: DetailedSummaryState::default(), messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: Vec::new(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: { + let project_snapshot = Self::project_snapshot(project.clone(), cx); + cx.foreground_executor() + .spawn(async move { Some(project_snapshot.await) }) + .shared() + }, context_server_registry, profile_id, project_context, @@ -512,8 +521,12 @@ impl Thread { } } + pub fn id(&self) -> &acp::SessionId { + &self.id + } + pub fn from_db( - id: ThreadId, + id: acp::SessionId, db_thread: DbThread, project: Entity, project_context: Rc>, @@ -529,12 +542,17 @@ impl Thread { Self { id, prompt_id: PromptId::new(), + title: ThreadTitle::Done(Ok(db_thread.title.clone())), + summary: db_thread.summary, messages: db_thread.messages, completion_mode: CompletionMode::Normal, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: db_thread.request_token_usage.clone(), + cumulative_token_usage: db_thread.cumulative_token_usage.clone(), + initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), context_server_registry, profile_id, project_context, @@ -542,9 +560,35 @@ impl Thread { model, project, action_log, + updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) } } + pub fn to_db(&self, cx: &App) -> Task { + let initial_project_snapshot = self.initial_project_snapshot.clone(); + let mut thread = DbThread { + title: self.title.unwrap_or_default(), + messages: self.messages.clone(), + updated_at: self.updated_at.clone(), + summary: self.summary.clone(), + initial_project_snapshot: None, + cumulative_token_usage: self.cumulative_token_usage.clone(), + request_token_usage: self.request_token_usage.clone(), + model: Some(DbLanguageModel { + provider: self.model.provider_id().to_string(), + model: self.model.name().0.to_string(), + }), + completion_mode: Some(self.completion_mode.into()), + profile: Some(self.profile_id.clone()), + }; + + cx.background_spawn(async move { + let initial_project_snapshot = initial_project_snapshot.await; + thread.initial_project_snapshot = initial_project_snapshot; + thread + }) + } + pub fn replay( &mut self, cx: &mut Context, @@ -630,6 +674,122 @@ impl Thread { ); } + /// Create a snapshot of the current project state including git information and unsaved buffers. + fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_, cx| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + let mut unsaved_buffers = Vec::new(); + cx.update(|app_cx| { + let buffer_store = project.read(app_cx).buffer_store(); + for buffer_handle in buffer_store.read(app_cx).buffers() { + let buffer = buffer_handle.read(app_cx); + if buffer.is_dirty() { + if let Some(file) = buffer.file() { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); + } + } + } + }) + .ok(); + + Arc::new(ProjectSnapshot { + worktree_snapshots, + unsaved_buffer_paths: unsaved_buffers, + timestamp: Utc::now(), + }) + }) + } + + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().to_string(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + WorktreeSnapshot { + worktree_path, + git_state, + } + }) + } + pub fn project(&self) -> &Entity { &self.project } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 01fa77e22d..c320e8ea72 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -522,7 +522,7 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates}; + use crate::{ContextServerRegistry, Templates, generate_session_id}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -547,6 +547,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -748,6 +749,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -890,6 +892,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1019,6 +1022,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1157,6 +1161,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1267,6 +1272,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1349,6 +1355,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1434,6 +1441,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1516,6 +1524,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry, From 4b1a48e4de6255e434bd8efe942eef3abf354013 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 10:40:15 -0600 Subject: [PATCH 17/25] Wire up history completely Co-authored-by: Antonio Scandurra --- crates/acp_thread/src/connection.rs | 3 +- crates/agent2/src/agent.rs | 69 ++++++++++++++------------ crates/agent2/src/db.rs | 8 +-- crates/agent2/src/history_store.rs | 58 +++++++--------------- crates/agent2/src/tests/mod.rs | 4 +- crates/agent2/src/thread.rs | 32 ++++++------ crates/agent_ui/src/acp/thread_view.rs | 17 ++++--- 7 files changed, 93 insertions(+), 98 deletions(-) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 398222a831..94b5fe015a 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; -use futures::channel::mpsc::UnboundedReceiver; use gpui::{Entity, SharedString, Task}; use project::Project; use serde::{Deserialize, Serialize}; @@ -27,6 +26,8 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; + // todo!(expose a history trait, and include list_threads and load_thread) + // todo!(write a test) fn list_threads( &self, _cx: &mut App, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 403a59e51b..6de5445d80 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,16 +5,15 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id}; +use crate::{ThreadsDatabase, generate_session_id}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; -use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use futures::future::Shared; -use futures::{SinkExt, StreamExt, future}; +use futures::channel::mpsc; +use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; @@ -30,6 +29,7 @@ use std::collections::HashMap; use std::path::Path; use std::rc::Rc; use std::sync::Arc; +use std::time::Duration; use util::ResultExt; const RULES_FILE_NAMES: [&'static str; 9] = [ @@ -174,7 +174,7 @@ pub struct NativeAgent { prompt_store: Option>, thread_database: Arc, history: watch::Sender>>, - load_history: Task>, + load_history: Task<()>, fs: Arc, _subscriptions: Vec, } @@ -212,7 +212,7 @@ impl NativeAgent { let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); - let this = Self { + let mut this = Self { sessions: HashMap::new(), project_context: Rc::new(RefCell::new(project_context)), project_context_needs_refresh: project_context_needs_refresh_tx, @@ -229,7 +229,7 @@ impl NativeAgent { prompt_store, fs, history: watch::channel(None).0, - load_history: Task::ready(Ok(())), + load_history: Task::ready(()), _subscriptions: subscriptions, }; this.reload_history(cx); @@ -249,7 +249,7 @@ impl NativeAgent { Session { thread: thread.clone(), acp_thread: acp_thread.downgrade(), - save_task: Task::ready(()), + save_task: Task::ready(Ok(())), _subscriptions: vec![ cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); @@ -280,24 +280,30 @@ impl NativeAgent { } fn reload_history(&mut self, cx: &mut Context) { + dbg!(""); let thread_database = self.thread_database.clone(); self.load_history = cx.spawn(async move |this, cx| { let results = cx .background_spawn(async move { let results = thread_database.list_threads().await?; - Ok(results - .into_iter() - .map(|thread| AcpThreadMetadata { - agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id.into(), - title: thread.title, - updated_at: thread.updated_at, - }) - .collect()) + dbg!(&results); + anyhow::Ok( + results + .into_iter() + .map(|thread| AcpThreadMetadata { + agent: NATIVE_AGENT_SERVER_NAME.clone(), + id: thread.id.into(), + title: thread.title, + updated_at: thread.updated_at, + }) + .collect(), + ) }) - .await?; - this.update(cx, |this, cx| this.history.send(Some(results)))?; - anyhow::Ok(()) + .await; + if let Some(results) = results.log_err() { + this.update(cx, |this, _| this.history.send(Some(results))) + .ok(); + } }); } @@ -509,10 +515,10 @@ impl NativeAgent { ) { self.models.refresh_list(cx); for session in self.sessions.values_mut() { - session.thread.update(cx, |thread, _| { + session.thread.update(cx, |thread, cx| { let model_id = LanguageModels::model_id(&thread.model()); if let Some(model) = self.models.model_from_id(&model_id) { - thread.set_model(model.clone()); + thread.set_model(model.clone(), cx); } }); } @@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; - thread.update(cx, |thread, _cx| { - thread.set_model(model.clone()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); }); update_settings_file::( @@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { - let thread_id = ThreadId::from(session_id.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { - let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(thread_id.clone()) + .load_thread(session_id.clone()) .await? .context("no such thread found")?; @@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let thread = cx.new(|cx| { let mut thread = Thread::from_db( - thread_id, + session_id, db_thread, project.clone(), agent.project_context.clone(), @@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.insert_session(session_id, thread, acp_thread, cx) + agent.insert_session(thread.clone(), acp_thread.clone(), cx) })?; let events = thread.update(cx, |thread, cx| thread.replay(cx))?; @@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, _cx| thread.cancel()); + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); } }); } @@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity); impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + Task::ready( + self.0 + .update(cx, |thread, cx| thread.truncate(message_id, cx)), + ) } } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 5332da276f..a7240df5c7 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,4 @@ -use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent}; +use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { - pub id: ThreadId, + pub id: acp::SessionId, #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, @@ -323,7 +323,7 @@ impl ThreadsDatabase { for (id, summary, updated_at) in rows { threads.push(DbThreadMetadata { - id: ThreadId(id), + id: acp::SessionId(id), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), }); @@ -333,7 +333,7 @@ impl ThreadsDatabase { }) } - pub fn load_thread(&self, id: ThreadId) -> Task>> { + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { let connection = self.connection.clone(); self.executor.spawn(async move { diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index fb4f34f9c5..f4e53c4c23 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -1,17 +1,13 @@ use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; -use agent::{ThreadId, thread_store::ThreadStore}; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_context::SavedContextMetadata; use chrono::{DateTime, Utc}; use collections::HashMap; -use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; -use itertools::Itertools; -use paths::contexts_dir; +use gpui::{SharedString, Task, prelude::*}; use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; -use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; -use util::ResultExt as _; +use std::{path::Path, sync::Arc, time::Duration}; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; @@ -64,16 +60,16 @@ enum SerializedRecentOpen { } pub struct AgentHistory { - entries: HashMap, - _task: Task>, + entries: watch::Receiver>>, + _task: Task<()>, } pub struct HistoryStore { - agents: HashMap, + agents: HashMap, // todo!() text threads } impl HistoryStore { - pub fn new(cx: &mut Context) -> Self { + pub fn new(_cx: &mut Context) -> Self { Self { agents: HashMap::default(), } @@ -88,33 +84,18 @@ impl HistoryStore { let Some(mut history) = connection.list_threads(cx) else { return; }; - let task = cx.spawn(async move |this, cx| { - while let Some(updated_history) = history.next().await { - dbg!(&updated_history); - this.update(cx, |this, cx| { - for entry in updated_history { - let agent = this - .agents - .get_mut(&entry.agent) - .context("agent not found")?; - agent.entries.insert(entry.id.clone(), entry); - } - cx.notify(); - anyhow::Ok(()) - })?? - } - Ok(()) - }); - self.agents.insert( - agent_name, - AgentHistory { - entries: Default::default(), - _task: task, - }, - ); + let history = AgentHistory { + entries: history.clone(), + _task: cx.spawn(async move |this, cx| { + while history.changed().await.is_ok() { + this.update(cx, |_, cx| cx.notify()).ok(); + } + }), + }; + self.agents.insert(agent_name.clone(), history); } - pub fn entries(&self, cx: &mut Context) -> Vec { + pub fn entries(&mut self, _cx: &mut Context) -> Vec { let mut history_entries = Vec::new(); #[cfg(debug_assertions)] @@ -124,9 +105,8 @@ impl HistoryStore { history_entries.extend( self.agents - .values() - .flat_map(|agent| agent.entries.values()) - .cloned() + .values_mut() + .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?") .map(HistoryEntry::Thread), ); // todo!() include the text threads in here. @@ -135,7 +115,7 @@ impl HistoryStore { history_entries } - pub fn recent_entries(&self, limit: usize, cx: &mut Context) -> Vec { + pub fn recent_entries(&mut self, limit: usize, cx: &mut Context) -> Vec { self.entries(cx).into_iter().take(limit).collect() } } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 75a21a2baa..2a4d306290 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, _cx| thread.cancel()); + thread.update(cx, |thread, cx| thread.cancel(cx)); let events = events.collect::>().await; let last_event = events.last(); assert!( @@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, _cx| thread.truncate(message_id)) + .update(cx, |thread, cx| thread.truncate(message_id, cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index ec820c7b5f..7ea5ff7cc6 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -802,16 +802,18 @@ impl Thread { &self.model } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } - pub fn set_completion_mode(&mut self, mode: CompletionMode) { + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { self.completion_mode = mode; + cx.notify() } #[cfg(any(test, feature = "test-support"))] @@ -839,21 +841,22 @@ impl Thread { self.profile_id = profile_id; } - pub fn cancel(&mut self) { + pub fn cancel(&mut self, cx: &mut Context) { if let Some(running_turn) = self.running_turn.take() { running_turn.cancel(); } - self.flush_pending_message(); + self.flush_pending_message(cx); } - pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { - self.cancel(); + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { return Err(anyhow!("Message not found")); }; self.messages.truncate(position); + cx.notify(); Ok(()) } @@ -900,7 +903,7 @@ impl Thread { } fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { - self.cancel(); + self.cancel(cx); let model = self.model.clone(); let (events_tx, events_rx) = mpsc::unbounded::>(); @@ -938,8 +941,8 @@ impl Thread { LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.messages.truncate(message_ix); })?; return Ok(()); @@ -991,7 +994,7 @@ impl Thread { log::info!("No tool uses found, completing turn"); return Ok(()); } else { - this.update(cx, |this, _| this.flush_pending_message())?; + this.update(cx, |this, cx| this.flush_pending_message(cx))?; completion_intent = CompletionIntent::ToolResults; } } @@ -1005,8 +1008,8 @@ impl Thread { log::info!("Turn execution completed successfully"); } - this.update(cx, |this, _| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.running_turn.take(); }) .ok(); @@ -1046,7 +1049,7 @@ impl Thread { match event { StartMessage { .. } => { - self.flush_pending_message(); + self.flush_pending_message(cx); self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), @@ -1255,7 +1258,7 @@ impl Thread { self.pending_message.get_or_insert_default() } - fn flush_pending_message(&mut self) { + fn flush_pending_message(&mut self, cx: &mut Context) { let Some(mut message) = self.pending_message.take() else { return; }; @@ -1280,6 +1283,7 @@ impl Thread { } self.messages.push(Message::Agent(message)); + cx.notify() } pub(crate) fn build_completion_request( diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 38036ad3c4..40517e49a0 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -2487,12 +2487,15 @@ impl AcpThreadView { return; }; - thread.update(cx, |thread, _cx| { + thread.update(cx, |thread, cx| { let current_mode = thread.completion_mode(); - thread.set_completion_mode(match current_mode { - CompletionMode::Burn => CompletionMode::Normal, - CompletionMode::Normal => CompletionMode::Burn, - }); + thread.set_completion_mode( + match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }, + cx, + ); }); } @@ -3274,8 +3277,8 @@ impl AcpThreadView { .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) .on_click({ cx.listener(move |this, _, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.set_completion_mode(CompletionMode::Burn); + thread.update(cx, |thread, cx| { + thread.set_completion_mode(CompletionMode::Burn, cx); }); this.resume_chat(cx); }) From fc076e84caf2c6de0856ca92a1b49d6f6e50fcf4 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 11:24:31 -0600 Subject: [PATCH 18/25] History mostly working Co-authored-by: Antonio Scandurra --- crates/agent2/Cargo.toml | 1 + crates/agent2/src/agent.rs | 68 ++++++++++++++++++++++ crates/agent2/src/agent2.rs | 3 +- crates/agent2/src/db.rs | 3 + crates/agent2/src/history_store.rs | 19 +++--- crates/agent2/src/thread.rs | 1 + crates/agent_ui/src/acp/thread_history.rs | 4 +- crates/agent_ui/src/agent_panel.rs | 6 +- crates/language_model/src/fake_provider.rs | 5 ++ 9 files changed, 95 insertions(+), 15 deletions(-) diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 88da875930..a32b4fe939 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -66,6 +66,7 @@ assistant_context.workspace = true [dev-dependencies] agent = { workspace = true, "features" = ["test-support"] } +acp_thread = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 6de5445d80..82ee339426 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -263,15 +263,20 @@ impl NativeAgent { } fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + dbg!(); let id = thread.read(cx).id().clone(); + dbg!(); let Some(session) = self.sessions.get_mut(&id) else { return; }; + dbg!(); let thread = thread.downgrade(); let thread_database = self.thread_database.clone(); + dbg!(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + dbg!(); let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; thread_database.save_thread(id, db_thread).await?; this.update(cx, |this, cx| this.reload_history(cx))?; @@ -1049,12 +1054,15 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { #[cfg(test)] mod tests { + use crate::{HistoryEntry, HistoryStore}; + use super::*; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; use settings::SettingsStore; + use util::path; #[gpui::test] async fn test_maintaining_project_context(cx: &mut TestAppContext) { @@ -1229,6 +1237,66 @@ mod tests { ); } + #[gpui::test] + async fn test_history(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let agent = NativeAgent::new( + project.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let model = cx.update(|cx| { + LanguageModelRegistry::global(cx) + .read(cx) + .default_model() + .unwrap() + .model + }); + let connection = NativeAgentConnection(agent.clone()); + let history_store = cx.new(|cx| { + let mut store = HistoryStore::new(cx); + store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx); + store + }); + + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let selector = connection.model_selector().unwrap(); + + let model = cx + .update(|cx| selector.selected_model(&session_id, cx)) + .await + .expect("selected_model should succeed"); + let model = cx + .update(|cx| agent.read(cx).models().model_from_id(&model.id)) + .unwrap(); + let model = model.as_fake(); + + let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + model.send_last_completion_stream_text_chunk("Hey"); + model.end_last_completion_stream(); + dbg!(send.await.unwrap()); + cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); + + let history = history_store.update(cx, |store, cx| store.entries(cx)); + assert_eq!(history.len(), 1); + assert_eq!(history[0].title(), "Hi"); + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index eee9810cef..6d1d266ada 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,6 +1,6 @@ mod agent; mod db; -pub mod history_store; +mod history_store; mod native_agent_server; mod templates; mod thread; @@ -11,6 +11,7 @@ mod tests; pub use agent::*; pub use db::*; +pub use history_store::*; pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index a7240df5c7..d5d882bcde 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -386,6 +386,9 @@ impl ThreadsDatabase { #[cfg(test)] mod tests { + use crate::NativeAgent; + use crate::Templates; + use super::*; use agent::MessageSegment; use agent::context::LoadedContext; diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index f4e53c4c23..d7d0ba2874 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -13,33 +13,34 @@ const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50); +// todo!(put this in the UI) #[derive(Clone, Debug)] pub enum HistoryEntry { - Thread(AcpThreadMetadata), - Context(SavedContextMetadata), + AcpThread(AcpThreadMetadata), + TextThread(SavedContextMetadata), } impl HistoryEntry { pub fn updated_at(&self) -> DateTime { match self { - HistoryEntry::Thread(thread) => thread.updated_at, - HistoryEntry::Context(context) => context.mtime.to_utc(), + HistoryEntry::AcpThread(thread) => thread.updated_at, + HistoryEntry::TextThread(context) => context.mtime.to_utc(), } } pub fn id(&self) -> HistoryEntryId { match self { - HistoryEntry::Thread(thread) => { + HistoryEntry::AcpThread(thread) => { HistoryEntryId::Thread(thread.agent.clone(), thread.id.clone()) } - HistoryEntry::Context(context) => HistoryEntryId::Context(context.path.clone()), + HistoryEntry::TextThread(context) => HistoryEntryId::Context(context.path.clone()), } } pub fn title(&self) -> &SharedString { match self { - HistoryEntry::Thread(thread) => &thread.title, - HistoryEntry::Context(context) => &context.title, + HistoryEntry::AcpThread(thread) => &thread.title, + HistoryEntry::TextThread(context) => &context.title, } } } @@ -107,7 +108,7 @@ impl HistoryStore { self.agents .values_mut() .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?") - .map(HistoryEntry::Thread), + .map(HistoryEntry::AcpThread), ); // todo!() include the text threads in here. diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 7ea5ff7cc6..0d81da7a92 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1283,6 +1283,7 @@ impl Thread { } self.messages.push(Message::Agent(message)); + dbg!("!!!!!!!!!!!!!!!!!!!!!!!"); cx.notify() } diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index f4750281f9..ed13912682 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -236,10 +236,10 @@ impl AcpThreadHistory { for (idx, entry) in all_entries.iter().enumerate() { match entry { - HistoryEntry::Thread(thread) => { + HistoryEntry::AcpThread(thread) => { candidates.push(StringMatchCandidate::new(idx, &thread.title)); } - HistoryEntry::Context(context) => { + HistoryEntry::TextThread(context) => { candidates.push(StringMatchCandidate::new(idx, &context.title)); } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 102facadd1..6d1e7eb846 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -6,7 +6,7 @@ use std::time::Duration; use acp_thread::AcpThreadMetadata; use agent_servers::AgentServer; -use agent2::history_store::HistoryEntry; +use agent2::HistoryEntry; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use serde::{Deserialize, Serialize}; @@ -752,7 +752,7 @@ impl AgentPanel { &acp_history, window, |this, _, event, window, cx| match event { - ThreadHistoryEvent::Open(HistoryEntry::Thread(thread)) => { + ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => { let agent_choice = match thread.agent.0.as_ref() { "Claude Code" => Some(ExternalAgent::ClaudeCode), "Gemini" => Some(ExternalAgent::Gemini), @@ -761,7 +761,7 @@ impl AgentPanel { }; this.new_external_thread(agent_choice, Some(thread.clone()), window, cx); } - ThreadHistoryEvent::Open(HistoryEntry::Context(thread)) => { + ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => { todo!() } }, diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index a9c7d5c034..d219cb6e35 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -102,6 +102,8 @@ pub struct FakeLanguageModel { impl Default for FakeLanguageModel { fn default() -> Self { + dbg!("default......"); + eprintln!("{}", std::backtrace::Backtrace::force_capture()); Self { provider_id: LanguageModelProviderId::from("fake".to_string()), provider_name: LanguageModelProviderName::from("Fake".to_string()), @@ -149,12 +151,14 @@ impl FakeLanguageModel { } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { + dbg!("remove..."); self.current_completion_txs .lock() .retain(|(req, _)| req != request); } pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into) { + dbg!("read..."); self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk); } @@ -223,6 +227,7 @@ impl LanguageModel for FakeLanguageModel { >, > { let (tx, rx) = mpsc::unbounded(); + dbg!("insert..."); self.current_completion_txs.lock().push((request, tx)); async move { Ok(rx.map(Ok).boxed()) }.boxed() } From cc196427f01550c03405cf1009258c6c1a308d1e Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 11:25:22 -0600 Subject: [PATCH 19/25] Remove dbg! Co-authored-by: Antonio Scandurra --- crates/agent2/src/agent.rs | 9 +-------- crates/agent2/src/thread.rs | 1 - crates/agent2/src/tools/edit_file_tool.rs | 1 - crates/agent_ui/src/acp/thread_history.rs | 1 - crates/language_model/src/fake_provider.rs | 5 ----- 5 files changed, 1 insertion(+), 16 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 82ee339426..c7f0840062 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -263,20 +263,15 @@ impl NativeAgent { } fn save_thread(&mut self, thread: Entity, cx: &mut Context) { - dbg!(); let id = thread.read(cx).id().clone(); - dbg!(); let Some(session) = self.sessions.get_mut(&id) else { return; }; - dbg!(); let thread = thread.downgrade(); let thread_database = self.thread_database.clone(); - dbg!(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; - dbg!(); let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; thread_database.save_thread(id, db_thread).await?; this.update(cx, |this, cx| this.reload_history(cx))?; @@ -285,13 +280,11 @@ impl NativeAgent { } fn reload_history(&mut self, cx: &mut Context) { - dbg!(""); let thread_database = self.thread_database.clone(); self.load_history = cx.spawn(async move |this, cx| { let results = cx .background_spawn(async move { let results = thread_database.list_threads().await?; - dbg!(&results); anyhow::Ok( results .into_iter() @@ -1289,7 +1282,7 @@ mod tests { cx.run_until_parked(); model.send_last_completion_stream_text_chunk("Hey"); model.end_last_completion_stream(); - dbg!(send.await.unwrap()); + send.await.unwrap(); cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); let history = history_store.update(cx, |store, cx| store.entries(cx)); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 0d81da7a92..7ea5ff7cc6 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1283,7 +1283,6 @@ impl Thread { } self.messages.push(Message::Agent(message)); - dbg!("!!!!!!!!!!!!!!!!!!!!!!!"); cx.notify() } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index c320e8ea72..f48ea7e86a 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -440,7 +440,6 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Result<()> { - dbg!(&output); event_stream.update_diff(cx.new(|cx| { Diff::finalized( output.input_path, diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index ed13912682..7cc3ed3b9b 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -98,7 +98,6 @@ impl AcpThreadHistory { }) .detach(); - dbg!("hello!"); let search_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); editor.set_placeholder_text("Search threads...", cx); diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index d219cb6e35..a9c7d5c034 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -102,8 +102,6 @@ pub struct FakeLanguageModel { impl Default for FakeLanguageModel { fn default() -> Self { - dbg!("default......"); - eprintln!("{}", std::backtrace::Backtrace::force_capture()); Self { provider_id: LanguageModelProviderId::from("fake".to_string()), provider_name: LanguageModelProviderName::from("Fake".to_string()), @@ -151,14 +149,12 @@ impl FakeLanguageModel { } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { - dbg!("remove..."); self.current_completion_txs .lock() .retain(|(req, _)| req != request); } pub fn send_last_completion_stream_text_chunk(&self, chunk: impl Into) { - dbg!("read..."); self.send_completion_stream_text_chunk(self.pending_completions().last().unwrap(), chunk); } @@ -227,7 +223,6 @@ impl LanguageModel for FakeLanguageModel { >, > { let (tx, rx) = mpsc::unbounded(); - dbg!("insert..."); self.current_completion_txs.lock().push((request, tx)); async move { Ok(rx.map(Ok).boxed()) }.boxed() } From 5d88de13dac1bc0ccc800a2e545e7f2fb032bfca Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 14:16:03 -0600 Subject: [PATCH 20/25] Saving history with thread titles --- crates/acp_thread/src/acp_thread.rs | 7 + crates/agent2/src/agent.rs | 60 +++++++-- crates/agent2/src/db.rs | 2 - crates/agent2/src/history_store.rs | 3 +- crates/agent2/src/tests/mod.rs | 1 + crates/agent2/src/thread.rs | 155 ++++++++++++++++++++-- crates/agent2/src/tools/edit_file_tool.rs | 143 +++----------------- crates/agent_ui/src/acp/thread_history.rs | 11 +- crates/agent_ui/src/acp/thread_view.rs | 1 + crates/agent_ui/src/agent_diff.rs | 41 +++--- 10 files changed, 241 insertions(+), 183 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a0e62c29e3..033c5dd93c 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -691,6 +691,7 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, + TitleUpdated, EntryUpdated(usize), EntriesRemoved(Range), ToolAuthorizationRequired, @@ -934,6 +935,12 @@ impl AcpThread { cx.emit(AcpThreadEvent::NewEntry); } + pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { + self.title = title; + cx.emit(AcpThreadEvent::TitleUpdated); + Ok(()) + } + pub fn update_tool_call( &mut self, update: impl Into, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index c7f0840062..3e382d3864 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -255,6 +255,9 @@ impl NativeAgent { this.sessions.remove(acp_thread.session_id()); }), cx.observe(&thread, |this, thread, cx| { + thread.update(cx, |thread, cx| { + thread.generate_title_if_needed(cx); + }); this.save_thread(thread.clone(), cx) }), ], @@ -262,13 +265,14 @@ impl NativeAgent { ); } - fn save_thread(&mut self, thread: Entity, cx: &mut Context) { - let id = thread.read(cx).id().clone(); + fn save_thread(&mut self, thread_handle: Entity, cx: &mut Context) { + let thread = thread_handle.read(cx); + let id = thread.id().clone(); let Some(session) = self.sessions.get_mut(&id) else { return; }; - let thread = thread.downgrade(); + let thread = thread_handle.downgrade(); let thread_database = self.thread_database.clone(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; @@ -507,7 +511,7 @@ impl NativeAgent { fn handle_models_updated_event( &mut self, - _registry: Entity, + registry: Entity, _event: &language_model::Event, cx: &mut Context, ) { @@ -518,6 +522,11 @@ impl NativeAgent { if let Some(model) = self.models.model_from_id(&model_id) { thread.set_model(model.clone(), cx); } + let summarization_model = registry + .read(cx) + .thread_summary_model() + .map(|model| model.model.clone()); + thread.set_summarization_model(summarization_model, cx); }); } } @@ -641,6 +650,10 @@ impl NativeAgentConnection { thread.update_tool_call(update, cx) })??; } + ThreadEvent::TitleUpdate(title) => { + acp_thread + .update(cx, |thread, cx| thread.update_title(title, cx))??; + } ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); @@ -821,6 +834,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) })?; + let summarization_model = registry.thread_summary_model().map(|c| c.model); + let thread = cx.new(|cx| { let mut thread = Thread::new( session_id.clone(), @@ -830,6 +845,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { action_log.clone(), agent.templates.clone(), default_model, + summarization_model, cx, ); Self::register_tools(&mut thread, project, action_log, cx); @@ -894,7 +910,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Create Thread let thread = agent.update(cx, |agent, cx| { - let configured_model = LanguageModelRegistry::global(cx) + let language_model_registry = LanguageModelRegistry::global(cx); + let configured_model = language_model_registry .update(cx, |registry, cx| { db_thread .model @@ -915,6 +932,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .model_from_id(&LanguageModels::model_id(&configured_model.model)) .context("no model by id")?; + let summarization_model = language_model_registry + .read(cx) + .thread_summary_model() + .map(|c| c.model); + let thread = cx.new(|cx| { let mut thread = Thread::from_db( session_id, @@ -925,6 +947,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { action_log.clone(), agent.templates.clone(), model, + summarization_model, cx, ); Self::register_tools(&mut thread, project, action_log, cx); @@ -1047,12 +1070,13 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { #[cfg(test)] mod tests { - use crate::{HistoryEntry, HistoryStore}; + use crate::HistoryStore; use super::*; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; use fs::FakeFs; use gpui::TestAppContext; + use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; use util::path; @@ -1245,13 +1269,6 @@ mod tests { ) .await .unwrap(); - let model = cx.update(|cx| { - LanguageModelRegistry::global(cx) - .read(cx) - .default_model() - .unwrap() - .model - }); let connection = NativeAgentConnection(agent.clone()); let history_store = cx.new(|cx| { let mut store = HistoryStore::new(cx); @@ -1268,6 +1285,16 @@ mod tests { let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); let selector = connection.model_selector().unwrap(); + let summarization_model: Arc = + Arc::new(FakeLanguageModel::default()) as _; + + agent.update(cx, |agent, cx| { + let thread = agent.sessions.get(&session_id).unwrap().thread.clone(); + thread.update(cx, |thread, cx| { + thread.set_summarization_model(Some(summarization_model.clone()), cx); + }) + }); + let model = cx .update(|cx| selector.selected_model(&session_id, cx)) .await @@ -1283,11 +1310,16 @@ mod tests { model.send_last_completion_stream_text_chunk("Hey"); model.end_last_completion_stream(); send.await.unwrap(); + + summarization_model + .as_fake() + .send_last_completion_stream_text_chunk("Saying Hello"); + summarization_model.as_fake().end_last_completion_stream(); cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE); let history = history_store.update(cx, |store, cx| store.entries(cx)); assert_eq!(history.len(), 1); - assert_eq!(history[0].title(), "Hi"); + assert_eq!(history[0].title(), "Saying Hello"); } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index d5d882bcde..67dc8c5e98 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -386,8 +386,6 @@ impl ThreadsDatabase { #[cfg(test)] mod tests { - use crate::NativeAgent; - use crate::Templates; use super::*; use agent::MessageSegment; diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index d7d0ba2874..0622dd4f58 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -1,12 +1,11 @@ use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; use agent_client_protocol as acp; -use anyhow::{Context as _, Result}; use assistant_context::SavedContextMetadata; use chrono::{DateTime, Utc}; use collections::HashMap; use gpui::{SharedString, Task, prelude::*}; use serde::{Deserialize, Serialize}; -use smol::stream::StreamExt; + use std::{path::Path, sync::Arc, time::Duration}; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 2a4d306290..a7fc8d907a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1506,6 +1506,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { action_log, templates, model.clone(), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 7ea5ff7cc6..9048f7099b 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -5,7 +5,7 @@ use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use chrono::{DateTime, Utc}; @@ -24,7 +24,7 @@ use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage, }; use project::{ Project, @@ -75,6 +75,18 @@ impl Message { } } + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], + } + } + pub fn to_markdown(&self) -> String { match self { Message::User(message) => message.to_markdown(), @@ -82,6 +94,13 @@ impl Message { Message::Resume => "[resumed after tool use limit was reached]".into(), } } + + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, + } + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -426,6 +445,7 @@ pub enum ThreadEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + TitleUpdate(SharedString), Stop(acp::StopReason), } @@ -475,6 +495,7 @@ pub struct Thread { project_context: Rc>, templates: Arc, model: Arc, + summarization_model: Option>, project: Entity, action_log: Entity, } @@ -488,6 +509,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Arc, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); @@ -516,11 +538,37 @@ impl Thread { project_context, templates, model, + summarization_model, project, action_log, } } + #[cfg(any(test, feature = "test-support"))] + pub fn test( + model: Arc, + project: Entity, + action_log: Entity, + cx: &mut Context, + ) -> Self { + use crate::generate_session_id; + + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + + Self::new( + generate_session_id(), + project, + Rc::default(), + context_server_registry, + action_log, + Templates::new(), + model, + None, + cx, + ) + } + pub fn id(&self) -> &acp::SessionId { &self.id } @@ -534,6 +582,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Arc, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = db_thread @@ -558,6 +607,7 @@ impl Thread { project_context, templates, model, + summarization_model, project, action_log, updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) @@ -807,6 +857,15 @@ impl Thread { cx.notify() } + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() + } + pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } @@ -1018,6 +1077,86 @@ impl Thread { events_rx } + pub fn generate_title_if_needed(&mut self, cx: &mut Context) { + if !matches!(self.title, ThreadTitle::None) { + return; + } + + // todo!() copy logic from agent1 re: tool calls, etc.? + if self.messages.len() < 2 { + return; + } + + self.generate_title(cx); + } + + fn generate_title(&mut self, cx: &mut Context) { + let Some(model) = self.summarization_model.clone() else { + println!("No thread summary model"); + return; + }; + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() + }; + + for message in &self.messages { + request.messages.extend(message.to_request()); + } + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())], + cache: false, + }); + + let task = cx.spawn(async move |this, cx| { + let result = async { + let mut messages = model.stream_completion(request, &cx).await?; + + let mut new_summary = String::new(); + while let Some(event) = messages.next().await { + let Ok(event) = event else { + continue; + }; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { .. }, + ) => { + // this.update(cx, |thread, cx| { + // thread.update_model_request_usage(amount as u32, limit, cx); + // })?; + // todo!()? not sure if this is the right place to do this. + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + new_summary.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + + anyhow::Ok(new_summary.into()) + } + .await; + + this.update(cx, |this, cx| { + this.title = ThreadTitle::Done(result); + cx.notify(); + }) + .log_err(); + }); + + self.title = ThreadTitle::Pending(task); + } + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { @@ -1373,15 +1512,7 @@ impl Thread { ); let mut messages = vec![self.build_system_message()]; for message in &self.messages { - match message { - Message::User(message) => messages.push(message.to_request()), - Message::Agent(message) => messages.extend(message.to_request()), - Message::Resume => messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec!["Continue where you left off".into()], - cache: false, - }), - } + messages.extend(message.to_request()); } if let Some(message) = self.pending_message.as_ref() { @@ -1924,7 +2055,7 @@ impl From for acp::ContentBlock { annotations: None, uri: None, }), - UserMessageContent::Mention { uri, content } => { + UserMessageContent::Mention { .. } => { todo!() } } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index f48ea7e86a..f540349f82 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -521,7 +521,6 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates, generate_session_id}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -529,7 +528,6 @@ mod tests { use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; - use std::rc::Rc; use util::path; #[gpui::test] @@ -541,21 +539,8 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log, - Templates::new(), - model, - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); let result = cx .update(|cx| { let input = EditFileToolInput { @@ -743,21 +728,8 @@ mod tests { }); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx)); // First, test with format_on_save enabled cx.update(|cx| { @@ -885,22 +857,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx)); // First, test with remove_trailing_whitespace_on_save enabled cx.update(|cx| { @@ -1015,22 +974,10 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); fs.insert_tree("/root", json!({})).await; @@ -1154,22 +1101,10 @@ mod tests { fs.insert_tree("/project", json!({})).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let action_log = cx.new(|_| ActionLog::new(project.clone())); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project, - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test global config paths - these should require confirmation if they exist and are outside the project @@ -1266,21 +1201,9 @@ mod tests { .await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test files in different worktrees @@ -1349,21 +1272,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test edge cases @@ -1435,21 +1346,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry.clone(), - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); // Test different EditFileMode values @@ -1518,21 +1417,9 @@ mod tests { let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); let action_log = cx.new(|_| ActionLog::new(project.clone())); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - Thread::new( - generate_session_id(), - project.clone(), - Rc::default(), - context_server_registry, - action_log.clone(), - Templates::new(), - model.clone(), - cx, - ) - }); + let thread = cx.new(|cx| Thread::test(model, project, action_log, cx)); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry)); assert_eq!( diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index 7cc3ed3b9b..344790f26e 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -1,15 +1,12 @@ -use crate::{AgentPanel, RemoveSelectedThread}; +use crate::RemoveSelectedThread; use agent_servers::AgentServer; -use agent2::{ - NativeAgentServer, - history_store::{HistoryEntry, HistoryStore}, -}; +use agent2::{HistoryEntry, HistoryStore, NativeAgentServer}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, - UniformListScrollHandle, WeakEntity, Window, uniform_list, + UniformListScrollHandle, Window, uniform_list, }; use project::Project; use std::{fmt::Display, ops::Range, sync::Arc}; @@ -72,7 +69,7 @@ impl AcpThreadHistory { window: &mut Window, cx: &mut Context, ) -> Self { - let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx)); + let history_store = cx.new(|cx| agent2::HistoryStore::new(cx)); let agent = NativeAgentServer::new(project.read(cx).fs().clone()); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 40517e49a0..959c152525 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -687,6 +687,7 @@ impl AcpThreadView { AcpThreadEvent::ServerExited(status) => { self.thread_state = ThreadState::ServerExited { status: *status }; } + AcpThreadEvent::TitleUpdated => {} } cx.notify(); } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index b9e1ea5d0a..3d43c6883d 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -199,24 +199,21 @@ impl AgentDiffPane { let action_log = thread.action_log(cx).clone(); let mut this = Self { - _subscriptions: [ - Some( - cx.observe_in(&action_log, window, |this, _action_log, window, cx| { - this.update_excerpts(window, cx) - }), - ), + _subscriptions: vec![ + cx.observe_in(&action_log, window, |this, _action_log, window, cx| { + this.update_excerpts(window, cx) + }), match &thread { - AgentDiffThread::Native(thread) => { - Some(cx.subscribe(&thread, |this, _thread, event, cx| { - this.handle_thread_event(event, cx) - })) - } - AgentDiffThread::AcpThread(_) => None, + AgentDiffThread::Native(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_native_thread_event(event, cx) + }), + AgentDiffThread::AcpThread(thread) => cx + .subscribe(&thread, |this, _thread, event, cx| { + this.handle_acp_thread_event(event, cx) + }), }, - ] - .into_iter() - .flatten() - .collect(), + ], title: SharedString::default(), multibuffer, editor, @@ -324,13 +321,20 @@ impl AgentDiffPane { } } - fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { + fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context) { match event { ThreadEvent::SummaryGenerated => self.update_title(cx), _ => {} } } + fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context) { + match event { + AcpThreadEvent::TitleUpdated => self.update_title(cx), + _ => {} + } + } + pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { self.editor.update(cx, |editor, cx| { @@ -1521,7 +1525,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::EntriesRemoved(_) + AcpThreadEvent::TitleUpdated + | AcpThreadEvent::EntriesRemoved(_) | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error From 8373884cdb047e642f13b4c6ff70efcb09bb6317 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 14:24:53 -0600 Subject: [PATCH 21/25] TEMP --- crates/agent2/src/agent.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 3e382d3864..17693f687c 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -139,6 +139,7 @@ impl LanguageModels { &self, model_id: &acp_thread::AgentModelId, ) -> Option> { + dbg!(&self.models.len()); self.models.get(model_id).cloned() } @@ -823,6 +824,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let default_model = registry .default_model() .and_then(|default_model| { + dbg!("here!"); agent .models .model_from_id(&LanguageModels::model_id(&default_model.model)) From 3ed2b7691b7fcc0b116711bc10d52792cf31e70c Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 15:09:43 -0600 Subject: [PATCH 22/25] Generating thread title --- crates/agent2/src/agent.rs | 4 +-- crates/agent2/src/tests/mod.rs | 1 + crates/agent2/src/thread.rs | 34 +++++++++++++++++------ crates/agent2/src/tools/edit_file_tool.rs | 4 ++- crates/agent_ui/src/agent_panel.rs | 6 ++-- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5b5dbff589..be0bac047f 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -139,7 +139,6 @@ impl LanguageModels { &self, model_id: &acp_thread::AgentModelId, ) -> Option> { - dbg!(&self.models.len()); self.models.get(model_id).cloned() } @@ -277,6 +276,7 @@ impl NativeAgent { let thread_database = self.thread_database.clone(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; thread_database.save_thread(id, db_thread).await?; this.update(cx, |this, cx| this.reload_history(cx))?; @@ -527,7 +527,7 @@ impl NativeAgent { if thread.model().is_none() && let Some(model) = default_model.clone() { - thread.set_model(model); + thread.set_model(model, cx); cx.notify(); } let summarization_model = registry diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 2678a51126..9aac27dcd6 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1554,6 +1554,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { action_log, templates, Some(model.clone()), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 52430d67b9..93a6fad23a 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -495,6 +495,7 @@ pub struct Thread { project_context: Rc>, templates: Arc, model: Option>, + summarization_model: Option>, project: Entity, action_log: Entity, } @@ -508,6 +509,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Option>, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); @@ -561,7 +563,7 @@ impl Thread { context_server_registry, action_log, Templates::new(), - model, + Some(model), None, cx, ) @@ -604,7 +606,7 @@ impl Thread { profile_id, project_context, templates, - model, + model: Some(model), summarization_model, project, action_log, @@ -622,9 +624,9 @@ impl Thread { initial_project_snapshot: None, cumulative_token_usage: self.cumulative_token_usage.clone(), request_token_usage: self.request_token_usage.clone(), - model: Some(DbLanguageModel { - provider: self.model.provider_id().to_string(), - model: self.model.name().0.to_string(), + model: self.model.as_ref().map(|model| DbLanguageModel { + provider: model.provider_id().to_string(), + model: model.name().0.to_string(), }), completion_mode: Some(self.completion_mode.into()), profile: Some(self.profile_id.clone()), @@ -850,8 +852,18 @@ impl Thread { self.model.as_ref() } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = Some(model); + cx.notify() + } + + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { @@ -931,7 +943,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> Result>> where T: Into, { @@ -951,10 +963,13 @@ impl Thread { self.run_turn(cx) } - fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + fn run_turn( + &mut self, + cx: &mut Context, + ) -> Result>> { self.cancel(cx); - let model = self.model.clone(); + let model = self.model.clone().context("No language model configured")?; let (events_tx, events_rx) = mpsc::unbounded::>(); let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); @@ -1145,6 +1160,7 @@ impl Thread { }); self.title = ThreadTitle::Pending(task); + cx.notify() } pub fn build_system_message(&self) -> LanguageModelRequestMessage { diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index f540349f82..756698bf3f 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -257,8 +257,10 @@ impl AgentTool for EditFileTool { let (request, model, action_log) = self.thread.update(cx, |thread, cx| { let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); - (request, thread.model().clone(), thread.action_log().clone()) + (request, thread.model().cloned(), thread.action_log().clone()) })?; + let request = request?; + let model = model.context("No language model configured")?; let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 24599ab621..20e6206fa2 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1697,13 +1697,13 @@ impl AgentPanel { window.dispatch_action(NewTextThread.boxed_clone(), cx); } AgentType::NativeAgent => { - self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx) } AgentType::Gemini => { - self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx) } AgentType::ClaudeCode => { - self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx) } } } From 8998cdee2647977dc7acbc6b4ccd8a98b3074069 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 21:03:31 -0600 Subject: [PATCH 23/25] Wire through title update --- crates/acp_thread/src/acp_thread.rs | 1 + crates/agent2/src/agent.rs | 19 +++-- crates/agent2/src/history_store.rs | 1 + crates/agent2/src/thread.rs | 112 ++++++++++++++++------------ 4 files changed, 82 insertions(+), 51 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index bb9c2e35ea..58e171046f 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -937,6 +937,7 @@ impl AcpThread { } pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { + dbg!("update title", &title); self.title = title; cx.emit(AcpThreadEvent::TitleUpdated); Ok(()) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index be0bac047f..564511c632 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -244,20 +244,28 @@ impl NativeAgent { cx: &mut Context, ) { let id = thread.read(cx).id().clone(); + let weak_thread = acp_thread.downgrade(); self.sessions.insert( id, Session { thread: thread.clone(), - acp_thread: acp_thread.downgrade(), + acp_thread: weak_thread.clone(), save_task: Task::ready(Ok(())), _subscriptions: vec![ cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); }), - cx.observe(&thread, |this, thread, cx| { - thread.update(cx, |thread, cx| { - thread.generate_title_if_needed(cx); - }); + cx.observe(&thread, move |this, thread, cx| { + if let Some(response_stream) = + thread.update(cx, |thread, cx| thread.generate_title_if_needed(cx)) + { + NativeAgentConnection::handle_thread_events( + response_stream, + weak_thread.clone(), + cx, + ) + .detach_and_log_err(cx); + } this.save_thread(thread.clone(), cx) }), ], @@ -659,6 +667,7 @@ impl NativeAgentConnection { })??; } ThreadEvent::TitleUpdate(title) => { + dbg!("updating title"); acp_thread .update(cx, |thread, cx| thread.update_title(title, cx))??; } diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index 0622dd4f58..151cc4e8a9 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -87,6 +87,7 @@ impl HistoryStore { let history = AgentHistory { entries: history.clone(), _task: cx.spawn(async move |this, cx| { + dbg!("loaded", history.borrow().as_ref().map(|b| b.len())); while history.changed().await.is_ok() { this.update(cx, |_, cx| cx.notify()).ok(); } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 93a6fad23a..639c50957e 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -458,7 +458,7 @@ pub struct ToolCallAuthorization { enum ThreadTitle { None, - Pending(Task<()>), + Pending(Shared>), Done(Result), } @@ -1082,24 +1082,33 @@ impl Thread { Ok(events_rx) } - pub fn generate_title_if_needed(&mut self, cx: &mut Context) { + pub fn generate_title_if_needed( + &mut self, + cx: &mut Context, + ) -> Option>> { if !matches!(self.title, ThreadTitle::None) { - return; + return None; } // todo!() copy logic from agent1 re: tool calls, etc.? if self.messages.len() < 2 { - return; + return None; } + let Some(model) = self.summarization_model.clone() else { + return None; + }; + let (tx, rx) = mpsc::unbounded(); - self.generate_title(cx); + self.generate_title(model, ThreadEventStream(tx), cx); + Some(rx) } - fn generate_title(&mut self, cx: &mut Context) { - let Some(model) = self.summarization_model.clone() else { - println!("No thread summary model"); - return; - }; + fn generate_title( + &mut self, + model: Arc, + event_stream: ThreadEventStream, + cx: &mut Context, + ) { let mut request = LanguageModelRequest { intent: Some(CompletionIntent::ThreadSummarization), temperature: AgentSettings::temperature_for_model(&model, cx), @@ -1116,50 +1125,55 @@ impl Thread { cache: false, }); - let task = cx.spawn(async move |this, cx| { - let result = async { - let mut messages = model.stream_completion(request, &cx).await?; + let task = cx + .spawn(async move |this, cx| { + let result: anyhow::Result = async { + let mut messages = model.stream_completion(request, &cx).await?; - let mut new_summary = String::new(); - while let Some(event) = messages.next().await { - let Ok(event) = event else { - continue; - }; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { .. }, - ) => { - // this.update(cx, |thread, cx| { - // thread.update_model_request_usage(amount as u32, limit, cx); - // })?; - // todo!()? not sure if this is the right place to do this. + let mut new_summary = String::new(); + while let Some(event) = messages.next().await { + let Ok(event) = event else { continue; + }; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { .. }, + ) => { + // this.update(cx, |thread, cx| { + // thread.update_model_request_usage(amount as u32, limit, cx); + // })?; + // todo!()? not sure if this is the right place to do this. + continue; + } + _ => continue, + }; + + let mut lines = text.lines(); + new_summary.extend(lines.next()); + + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; } - _ => continue, - }; - - let mut lines = text.lines(); - new_summary.extend(lines.next()); - - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; } + + anyhow::Ok(new_summary.into()) } + .await; - anyhow::Ok(new_summary.into()) - } - .await; - - this.update(cx, |this, cx| { - this.title = ThreadTitle::Done(result); - cx.notify(); + this.update(cx, |this, cx| { + if let Ok(title) = &result { + event_stream.send_title_update(title.clone()); + } + this.title = ThreadTitle::Done(result); + cx.notify(); + }) + .log_err(); }) - .log_err(); - }); + .shared(); - self.title = ThreadTitle::Pending(task); + self.title = ThreadTitle::Pending(task.clone()); cx.notify() } @@ -1746,6 +1760,12 @@ impl ThreadEventStream { .ok(); } + fn send_title_update(&self, text: SharedString) { + self.0 + .unbounded_send(Ok(ThreadEvent::TitleUpdate(text))) + .ok(); + } + fn send_thinking(&self, text: &str) { self.0 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) From 3b7ad6236d7dde386e0a518684e2d2a366629f40 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 23:02:51 -0600 Subject: [PATCH 24/25] Re-wire history --- crates/acp_thread/src/acp_thread.rs | 1 - crates/acp_thread/src/connection.rs | 36 ++- crates/agent2/src/agent.rs | 261 +++++++++++----------- crates/agent2/src/db.rs | 28 ++- crates/agent2/src/history_store.rs | 96 ++++++-- crates/agent_ui/src/acp/thread_history.rs | 33 +-- crates/agent_ui/src/acp/thread_view.rs | 47 +++- crates/agent_ui/src/agent_panel.rs | 2 + 8 files changed, 294 insertions(+), 210 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 58e171046f..bb9c2e35ea 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -937,7 +937,6 @@ impl AcpThread { } pub fn update_title(&mut self, title: SharedString, cx: &mut Context) -> Result<()> { - dbg!("update title", &title); self.title = title; cx.emit(AcpThreadEvent::TitleUpdated); Ok(()) diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 1e3272a6b0..af653a1c74 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,6 +2,7 @@ use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; +use futures::channel::mpsc::UnboundedReceiver; use gpui::{Entity, SharedString, Task}; use project::Project; use serde::{Deserialize, Serialize}; @@ -26,25 +27,6 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; - // todo!(expose a history trait, and include list_threads and load_thread) - // todo!(write a test) - fn list_threads( - &self, - _cx: &mut App, - ) -> Option>>> { - return None; - } - - fn load_thread( - self: Rc, - _project: Entity, - _cwd: &Path, - _session_id: acp::SessionId, - _cx: &mut App, - ) -> Task>> { - Task::ready(Err(anyhow::anyhow!("load thread not implemented"))) - } - fn auth_methods(&self) -> &[acp::AuthMethod]; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; @@ -82,6 +64,10 @@ pub trait AgentConnection { None } + fn history(self: Rc) -> Option> { + None + } + fn into_any(self: Rc) -> Rc; } @@ -99,6 +85,18 @@ pub trait AgentSessionResume { fn run(&self, cx: &mut App) -> Task>; } +pub trait AgentHistory { + fn list_threads(&self, cx: &mut App) -> Task>>; + fn observe_history(&self, cx: &mut App) -> UnboundedReceiver; + fn load_thread( + self: Rc, + _project: Entity, + _cwd: &Path, + _session_id: acp::SessionId, + _cx: &mut App, + ) -> Task>>; +} + #[derive(Debug)] pub struct AuthRequired; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 564511c632..cc3a40f652 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -6,7 +6,7 @@ use crate::{ UserMessageContent, WebSearchTool, templates::Templates, }; use crate::{ThreadsDatabase, generate_session_id}; -use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; +use acp_thread::{AcpThread, AcpThreadMetadata, AgentHistory, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; @@ -56,7 +56,7 @@ struct Session { thread: Entity, /// The ACP thread that handles protocol communication acp_thread: WeakEntity, - save_task: Task>, + save_task: Task<()>, _subscriptions: Vec, } @@ -173,8 +173,7 @@ pub struct NativeAgent { project: Entity, prompt_store: Option>, thread_database: Arc, - history: watch::Sender>>, - load_history: Task<()>, + history_watchers: Vec>, fs: Arc, _subscriptions: Vec, } @@ -212,7 +211,7 @@ impl NativeAgent { let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); - let mut this = Self { + Self { sessions: HashMap::new(), project_context: Rc::new(RefCell::new(project_context)), project_context_needs_refresh: project_context_needs_refresh_tx, @@ -228,12 +227,9 @@ impl NativeAgent { project, prompt_store, fs, - history: watch::channel(None).0, - load_history: Task::ready(()), + history_watchers: Vec::new(), _subscriptions: subscriptions, - }; - this.reload_history(cx); - this + } }) } @@ -250,7 +246,7 @@ impl NativeAgent { Session { thread: thread.clone(), acp_thread: weak_thread.clone(), - save_task: Task::ready(Ok(())), + save_task: Task::ready(()), _subscriptions: vec![ cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); @@ -285,35 +281,23 @@ impl NativeAgent { session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; - let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; - thread_database.save_thread(id, db_thread).await?; - this.update(cx, |this, cx| this.reload_history(cx))?; - Ok(()) - }); - } - - fn reload_history(&mut self, cx: &mut Context) { - let thread_database = self.thread_database.clone(); - self.load_history = cx.spawn(async move |this, cx| { - let results = cx - .background_spawn(async move { - let results = thread_database.list_threads().await?; - anyhow::Ok( - results - .into_iter() - .map(|thread| AcpThreadMetadata { - agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id.into(), - title: thread.title, - updated_at: thread.updated_at, - }) - .collect(), - ) + let Some(task) = thread.update(cx, |thread, cx| thread.to_db(cx)).ok() else { + return; + }; + let db_thread = task.await; + let metadata = thread_database + .save_thread(id.clone(), db_thread) + .await + .log_err(); + if let Some(metadata) = metadata { + this.update(cx, |this, _| { + for watcher in this.history_watchers.iter_mut() { + watcher + .unbounded_send(metadata.clone().to_acp(NATIVE_AGENT_SERVER_NAME)) + .log_err(); + } }) - .await; - if let Some(results) = results.log_err() { - this.update(cx, |this, _| this.history.send(Some(results))) - .ok(); + .ok(); } }); } @@ -667,7 +651,6 @@ impl NativeAgentConnection { })??; } ThreadEvent::TitleUpdate(title) => { - dbg!("updating title"); acp_thread .update(cx, |thread, cx| thread.update_title(title, cx))??; } @@ -884,11 +867,106 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn list_threads( + fn model_selector(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn prompt( &self, + id: Option, + params: acp::PromptRequest, cx: &mut App, - ) -> Option>>> { - Some(self.0.read(cx).history.receiver()) + ) -> Task> { + let id = id.expect("UserMessageId is required"); + let session_id = params.session_id.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + self.run_turn(session_id, cx, |thread, cx| { + let content: Vec = params + .prompt + .into_iter() + .map(Into::into) + .collect::>(); + log::info!("Converted prompt to message: {} chars", content.len()); + log::debug!("Message id: {:?}", id); + log::debug!("Message content: {:?}", content); + + thread.update(cx, |thread, cx| thread.send(id, content, cx)) + }) + } + + fn resume( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionResume { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling on session: {}", session_id); + self.0.update(cx, |agent, cx| { + if let Some(agent) = agent.sessions.get(session_id) { + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); + } + }); + } + + fn session_editor( + &self, + session_id: &agent_client_protocol::SessionId, + cx: &mut App, + ) -> Option> { + self.0.update(cx, |agent, _cx| { + agent + .sessions + .get(session_id) + .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + }) + } + + fn history(self: Rc) -> Option> { + Some(self) + } + + fn into_any(self: Rc) -> Rc { + self + } +} + +struct NativeAgentSessionEditor(Entity); + +impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { + fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + Task::ready( + self.0 + .update(cx, |thread, cx| thread.truncate(message_id, cx)), + ) + } +} + +impl acp_thread::AgentHistory for NativeAgentConnection { + fn list_threads(&self, cx: &mut App) -> Task>> { + let database = self.0.read(cx).thread_database.clone(); + cx.background_executor().spawn(async move { + let threads = database.list_threads().await?; + anyhow::Ok( + threads + .into_iter() + .map(|thread| thread.to_acp(NATIVE_AGENT_SERVER_NAME)) + .collect::>(), + ) + }) + } + + fn observe_history(&self, cx: &mut App) -> mpsc::UnboundedReceiver { + let (tx, rx) = mpsc::unbounded(); + self.0.update(cx, |this, _| this.history_watchers.push(tx)); + rx } fn load_thread( @@ -980,83 +1058,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Ok(acp_thread) }) } - - fn model_selector(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) - } - - fn prompt( - &self, - id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let id = id.expect("UserMessageId is required"); - let session_id = params.session_id.clone(); - log::info!("Received prompt request for session: {}", session_id); - log::debug!("Prompt blocks count: {}", params.prompt.len()); - - self.run_turn(session_id, cx, |thread, cx| { - let content: Vec = params - .prompt - .into_iter() - .map(Into::into) - .collect::>(); - log::info!("Converted prompt to message: {} chars", content.len()); - log::debug!("Message id: {:?}", id); - log::debug!("Message content: {:?}", content); - - thread.update(cx, |thread, cx| thread.send(id, content, cx)) - }) - } - - fn resume( - &self, - session_id: &acp::SessionId, - _cx: &mut App, - ) -> Option> { - Some(Rc::new(NativeAgentSessionResume { - connection: self.clone(), - session_id: session_id.clone(), - }) as _) - } - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - log::info!("Cancelling on session: {}", session_id); - self.0.update(cx, |agent, cx| { - if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, cx| thread.cancel(cx)); - } - }); - } - - fn session_editor( - &self, - session_id: &agent_client_protocol::SessionId, - cx: &mut App, - ) -> Option> { - self.0.update(cx, |agent, _cx| { - agent - .sessions - .get(session_id) - .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) - }) - } - - fn into_any(self: Rc) -> Rc { - self - } -} - -struct NativeAgentSessionEditor(Entity); - -impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { - fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready( - self.0 - .update(cx, |thread, cx| thread.truncate(message_id, cx)), - ) - } } struct NativeAgentSessionResume { @@ -1274,16 +1275,22 @@ mod tests { ) .await .unwrap(); - let connection = NativeAgentConnection(agent.clone()); - let history_store = cx.new(|cx| { - let mut store = HistoryStore::new(cx); - store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx); - store - }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + let history = connection.clone().history().unwrap(); + let history_store = cx.new(|cx| HistoryStore::get_or_init(cx)); + + history_store + .update(cx, |history_store, cx| { + history_store.load_history(NATIVE_AGENT_SERVER_NAME.clone(), history.as_ref(), cx) + }) + .await + .unwrap(); let acp_thread = cx .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx) + connection + .clone() + .new_thread(project.clone(), Path::new(path!("")), cx) }) .await .unwrap(); diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 67dc8c5e98..43979e8c74 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,5 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use acp_thread::{AcpThreadMetadata, AgentServerName}; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -30,6 +31,17 @@ pub struct DbThreadMetadata { pub updated_at: DateTime, } +impl DbThreadMetadata { + pub fn to_acp(self, agent: AgentServerName) -> AcpThreadMetadata { + AcpThreadMetadata { + agent, + id: self.id, + title: self.title, + updated_at: self.updated_at, + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct DbThread { pub title: SharedString, @@ -288,7 +300,7 @@ impl ThreadsDatabase { connection: &Arc>, id: acp::SessionId, thread: DbThread, - ) -> Result<()> { + ) -> Result { let json_data = serde_json::to_string(&thread)?; let title = thread.title.to_string(); let updated_at = thread.updated_at.to_rfc3339(); @@ -303,9 +315,13 @@ impl ThreadsDatabase { INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) "})?; - insert((id.0, title, updated_at, data_type, data))?; + insert((id.0.clone(), title, updated_at, data_type, data))?; - Ok(()) + Ok(DbThreadMetadata { + id, + title: thread.title, + updated_at: thread.updated_at, + }) } pub fn list_threads(&self) -> Task>> { @@ -360,7 +376,11 @@ impl ThreadsDatabase { }) } - pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> { + pub fn save_thread( + &self, + id: acp::SessionId, + thread: DbThread, + ) -> Task> { let connection = self.connection.clone(); self.executor diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index 151cc4e8a9..996702bff7 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -1,12 +1,17 @@ use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; use agent_client_protocol as acp; +use agent_servers::AgentServer; use assistant_context::SavedContextMetadata; use chrono::{DateTime, Utc}; use collections::HashMap; -use gpui::{SharedString, Task, prelude::*}; +use gpui::{Entity, Global, SharedString, Task, prelude::*}; +use project::Project; use serde::{Deserialize, Serialize}; +use ui::App; -use std::{path::Path, sync::Arc, time::Duration}; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; + +use crate::NativeAgentServer; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; @@ -59,41 +64,88 @@ enum SerializedRecentOpen { Context(String), } +#[derive(Default)] pub struct AgentHistory { - entries: watch::Receiver>>, - _task: Task<()>, + entries: HashMap, + loaded: bool, } pub struct HistoryStore { agents: HashMap, // todo!() text threads } +// note, we have to share the history store between all windows +// because we only get updates from one connection at a time. +struct GlobalHistoryStore(Entity); +impl Global for GlobalHistoryStore {} impl HistoryStore { - pub fn new(_cx: &mut Context) -> Self { + pub fn get_or_init(project: &Entity, cx: &mut App) -> Entity { + if cx.has_global::() { + return cx.global::().0.clone(); + } + let history_store = cx.new(|cx| HistoryStore::new(cx)); + cx.set_global(GlobalHistoryStore(history_store.clone())); + let root_dir = project + .read(cx) + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()) + .unwrap_or_else(|| paths::home_dir().as_path().into()); + + let agent = NativeAgentServer::new(project.read(cx).fs().clone()); + let connect = agent.connect(&root_dir, project, cx); + cx.spawn({ + let history_store = history_store.clone(); + async move |cx| { + let connection = connect.await?.history().unwrap(); + history_store + .update(cx, |history_store, cx| { + history_store.load_history(agent.name(), connection.as_ref(), cx) + })? + .await + } + }) + .detach_and_log_err(cx); + history_store + } + + fn new(_cx: &mut Context) -> Self { Self { agents: HashMap::default(), } } - pub fn register_agent( + pub fn update_history(&mut self, entry: AcpThreadMetadata, cx: &mut Context) { + let agent = self + .agents + .entry(entry.agent.clone()) + .or_insert(Default::default()); + + agent.entries.insert(entry.id.clone(), entry); + cx.notify() + } + + pub fn load_history( &mut self, agent_name: AgentServerName, - connection: &dyn AgentConnection, + connection: &dyn acp_thread::AgentHistory, cx: &mut Context, - ) { - let Some(mut history) = connection.list_threads(cx) else { - return; - }; - let history = AgentHistory { - entries: history.clone(), - _task: cx.spawn(async move |this, cx| { - dbg!("loaded", history.borrow().as_ref().map(|b| b.len())); - while history.changed().await.is_ok() { - this.update(cx, |_, cx| cx.notify()).ok(); - } - }), - }; - self.agents.insert(agent_name.clone(), history); + ) -> Task> { + let threads = connection.list_threads(cx); + cx.spawn(async move |this, cx| { + let threads = threads.await?; + + this.update(cx, |this, cx| { + this.agents.insert( + agent_name, + AgentHistory { + loaded: true, + entries: threads.into_iter().map(|t| (t.id.clone(), t)).collect(), + }, + ); + cx.notify() + }) + }) } pub fn entries(&mut self, _cx: &mut Context) -> Vec { @@ -107,7 +159,7 @@ impl HistoryStore { history_entries.extend( self.agents .values_mut() - .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?") + .flat_map(|history| history.entries.values().cloned()) // todo!("surface the loading state?") .map(HistoryEntry::AcpThread), ); // todo!() include the text threads in here. diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index 344790f26e..d0bf60ad72 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -5,8 +5,8 @@ use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; use gpui::{ - App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, - UniformListScrollHandle, Window, uniform_list, + App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ScrollStrategy, Stateful, + Task, UniformListScrollHandle, Window, uniform_list, }; use project::Project; use std::{fmt::Display, ops::Range, sync::Arc}; @@ -18,7 +18,7 @@ use ui::{ use util::ResultExt; pub struct AcpThreadHistory { - history_store: Entity, + pub(crate) history_store: Entity, scroll_handle: UniformListScrollHandle, selected_index: usize, hovered_index: Option, @@ -69,37 +69,12 @@ impl AcpThreadHistory { window: &mut Window, cx: &mut Context, ) -> Self { - let history_store = cx.new(|cx| agent2::HistoryStore::new(cx)); - - let agent = NativeAgentServer::new(project.read(cx).fs().clone()); - - let root_dir = project - .read(cx) - .visible_worktrees(cx) - .next() - .map(|worktree| worktree.read(cx).abs_path()) - .unwrap_or_else(|| paths::home_dir().as_path().into()); - - // todo!() reuse this connection for sending messages - let connect = agent.connect(&root_dir, project, cx); - cx.spawn(async move |this, cx| { - let connection = connect.await?; - this.update(cx, |this, cx| { - this.history_store.update(cx, |this, cx| { - this.register_agent(agent.name(), connection.as_ref(), cx) - }) - })?; - // todo!() we must keep it alive - std::mem::forget(connection); - anyhow::Ok(()) - }) - .detach(); - let search_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); editor.set_placeholder_text("Search threads...", cx); editor }); + let history_store = HistoryStore::get_or_init(project, cx); let search_editor_subscription = cx.subscribe(&search_editor, |this, search_editor, event, cx| { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 18eeda267d..676739da3b 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -18,6 +18,7 @@ use editor::scroll::Autoscroll; use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; +use futures::StreamExt; use gpui::{ Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, @@ -123,6 +124,7 @@ pub struct AcpThreadView { editor_expanded: bool, terminal_expanded: bool, editing_message: Option, + history_store: Entity, _cancel_task: Option>, _subscriptions: [Subscription; 3], } @@ -134,6 +136,7 @@ enum ThreadState { Ready { thread: Entity, _subscription: [Subscription; 2], + _history_task: Option>, }, LoadError(LoadError), Unauthenticated { @@ -149,6 +152,7 @@ impl AcpThreadView { agent: Rc, workspace: WeakEntity, project: Entity, + history_store: Entity, thread_store: Entity, text_thread_store: Entity, restore_thread: Option, @@ -196,11 +200,13 @@ impl AcpThreadView { thread_state: Self::initial_state( agent, restore_thread, + history_store.clone(), workspace, project, window, cx, ), + history_store, message_editor, model_selector: None, profile_selector: None, @@ -225,6 +231,7 @@ impl AcpThreadView { fn initial_state( agent: Rc, restore_thread: Option, + history_store: Entity, workspace: WeakEntity, project: Entity, window: &mut Window, @@ -251,6 +258,25 @@ impl AcpThreadView { } }; + let mut history_task = None; + let history = connection.clone().history(); + if let Some(history) = history.clone() { + if let Some(mut history) = cx.update(|_, cx| history.observe_history(cx)).ok() { + history_task = Some(cx.spawn(async move |cx| { + while let Some(update) = history.next().await { + if !history_store + .update(cx, |history_store, cx| { + history_store.update_history(update, cx) + }) + .is_ok() + { + break; + } + } + })); + } + } + // this.update_in(cx, |_this, _window, cx| { // let status = connection.exit_status(cx); // cx.spawn(async move |this, cx| { @@ -264,15 +290,12 @@ impl AcpThreadView { // .detach(); // }) // .ok(); - // + let history = connection.clone().history(); let task = cx.update(|_, cx| { - if let Some(restore_thread) = restore_thread { - connection.clone().load_thread( - project.clone(), - &root_dir, - restore_thread.id, - cx, - ) + if let Some(restore_thread) = restore_thread + && let Some(history) = history + { + history.load_thread(project.clone(), &root_dir, restore_thread.id, cx) } else { connection .clone() @@ -342,6 +365,7 @@ impl AcpThreadView { this.thread_state = ThreadState::Ready { thread, _subscription: [thread_subscription, action_log_subscription], + _history_task: history_task, }; this.profile_selector = this.as_native_thread(cx).map(|thread| { @@ -751,6 +775,7 @@ impl AcpThreadView { this.thread_state = Self::initial_state( agent, None, // todo!() + this.history_store.clone(), this.workspace.clone(), project.clone(), window, @@ -3755,6 +3780,8 @@ pub(crate) mod tests { cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); let text_thread_store = cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx))); let thread_view = cx.update(|window, cx| { cx.new(|cx| { @@ -3762,6 +3789,7 @@ pub(crate) mod tests { Rc::new(agent), workspace.downgrade(), project, + history_store.clone(), thread_store.clone(), text_thread_store.clone(), None, @@ -3954,6 +3982,8 @@ pub(crate) mod tests { cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx))); let text_thread_store = cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx))); + let history_store = + cx.update(|_window, cx| cx.new(|cx| agent2::HistoryStore::get_or_init(cx))); let connection = Rc::new(StubAgentConnection::new()); let thread_view = cx.update(|window, cx| { @@ -3962,6 +3992,7 @@ pub(crate) mod tests { Rc::new(StubAgentServer::new(connection.as_ref().clone())), workspace.downgrade(), project.clone(), + history_store, thread_store.clone(), text_thread_store.clone(), None, diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 20e6206fa2..8392c5589b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1031,11 +1031,13 @@ impl AgentPanel { }; this.update_in(cx, |this, window, cx| { + let acp_history_store = this.acp_history.read(cx).history_store.clone(); let thread_view = cx.new(|cx| { crate::acp::AcpThreadView::new( server, workspace.clone(), project, + acp_history_store, thread_store.clone(), text_thread_store.clone(), restore_thread, From 67e7d1426cbd726289c8064d495c0b54e65e59c6 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 18 Aug 2025 23:13:49 -0600 Subject: [PATCH 25/25] WIP --- crates/agent2/src/db.rs | 3 ++- crates/agent2/src/thread.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 43979e8c74..afc4fdcb3f 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -367,8 +367,9 @@ impl ThreadsDatabase { } DataType::Json => String::from_utf8(data)?, }; + dbg!(&json_data); - let thread = DbThread::from_json(json_data.as_bytes())?; + let thread = dbg!(DbThread::from_json(json_data.as_bytes()))?; Ok(Some(thread)) } else { Ok(None) diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 639c50957e..b3c62a3a64 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -610,7 +610,7 @@ impl Thread { summarization_model, project, action_log, - updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) + updated_at: db_thread.updated_at, } }