diff --git a/Cargo.lock b/Cargo.lock index 59299885d7..8e6f4b656a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -222,6 +222,7 @@ dependencies = [ "log", "lsp", "open", + "parking_lot", "paths", "portable-pty", "pretty_assertions", @@ -234,6 +235,7 @@ dependencies = [ "serde_json", "settings", "smol", + "sqlez", "task", "tempfile", "terminal", @@ -250,6 +252,7 @@ dependencies = [ "workspace-hack", "worktree", "zlog", + "zstd", ] [[package]] diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 12c94a522d..e24a5ec782 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -893,7 +893,7 @@ impl ThreadsDatabase { let needs_migration_from_heed = mdb_path.exists(); - let connection = if *ZED_STATELESS { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { Connection::open_memory(Some("THREAD_FALLBACK_DB")) } else { Connection::open_file(&sqlite_path.to_string_lossy()) diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 60a327c73b..de1b1f74ec 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -38,6 +38,7 @@ language_model.workspace = true language_models.workspace = true log.workspace = true open.workspace = true +parking_lot.workspace = true paths.workspace = true portable-pty.workspace = true project.workspace = true @@ -47,6 +48,7 @@ schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +sqlez.workspace = true smol.workspace = true task.workspace = true terminal.workspace = true @@ -58,8 +60,10 @@ watch.workspace = true web_search.workspace = true which.workspace = true workspace-hack.workspace = true +zstd.workspace = true [dev-dependencies] +agent = { workspace = true, "features" = ["test-support"] } ctor.workspace = true client = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index e17a1121b3..e8f85b0346 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,17 +1,34 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use agent::thread_store; +use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; -use anyhow::Result; +use anyhow::{Result, anyhow}; use chrono::{DateTime, Utc}; use collections::{HashMap, IndexMap}; +use futures::{FutureExt, future::Shared}; +use gpui::{BackgroundExecutor, Global, ReadGlobal, Task}; +use indoc::indoc; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use ui::SharedString; +use sqlez::{ + bindable::{Bind, Column}, + connection::Connection, + statement::Statement, +}; +use std::{path::PathBuf, sync::Arc}; +use ui::{App, SharedString}; pub type DbMessage = crate::Message; pub type DbSummary = agent::thread::DetailedSummaryState; pub type DbLanguageModel = thread_store::SerializedLanguageModel; -pub type DbThreadMetadata = thread_store::SerializedThreadMetadata; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbThreadMetadata { + pub id: acp::SessionId, + #[serde(alias = "summary")] + pub title: SharedString, + pub updated_at: DateTime, +} #[derive(Debug, Serialize, Deserialize)] pub struct DbThread { @@ -165,3 +182,292 @@ impl DbThread { }) } } + +pub static ZED_STATELESS: std::sync::LazyLock = + std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty())); + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DataType { + #[serde(rename = "json")] + Json, + #[serde(rename = "zstd")] + Zstd, +} + +impl Bind for DataType { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let value = match self { + DataType::Json => "json", + DataType::Zstd => "zstd", + }; + value.bind(statement, start_index) + } +} + +impl Column for DataType { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (value, next_index) = String::column(statement, start_index)?; + let data_type = match value.as_str() { + "json" => DataType::Json, + "zstd" => DataType::Zstd, + _ => anyhow::bail!("Unknown data type: {}", value), + }; + Ok((data_type, next_index)) + } +} + +struct GlobalThreadsDatabase(Shared, Arc>>>); + +impl Global for GlobalThreadsDatabase {} + +pub(crate) struct ThreadsDatabase { + executor: BackgroundExecutor, + connection: Arc>, +} + +impl ThreadsDatabase { + fn connection(&self) -> Arc> { + self.connection.clone() + } + + const COMPRESSION_LEVEL: i32 = 3; +} + +impl ThreadsDatabase { + fn global_future( + cx: &mut App, + ) -> Shared, Arc>>> { + GlobalThreadsDatabase::global(cx).0.clone() + } + + fn init(cx: &mut App) { + let executor = cx.background_executor().clone(); + let database_future = executor + .spawn({ + let executor = executor.clone(); + let threads_dir = paths::data_dir().join("threads"); + async move { + match ThreadsDatabase::new(threads_dir, executor) { + Ok(db) => Ok(Arc::new(db)), + Err(err) => Err(Arc::new(err)), + } + } + }) + .shared(); + + cx.set_global(GlobalThreadsDatabase(database_future)); + } + + pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result { + let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) { + Connection::open_memory(Some("THREAD_FALLBACK_DB")) + } else { + std::fs::create_dir_all(&threads_dir)?; + let sqlite_path = threads_dir.join("threads.db"); + Connection::open_file(&sqlite_path.to_string_lossy()) + }; + + connection.exec(indoc! {" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor: executor.clone(), + connection: Arc::new(Mutex::new(connection)), + }; + + Ok(db) + } + + fn save_thread_sync( + connection: &Arc>, + id: acp::SessionId, + thread: DbThread, + ) -> Result<()> { + let json_data = serde_json::to_string(&thread)?; + let title = thread.title.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + + let connection = connection.lock(); + + let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(Arc, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + "})?; + + insert((id.0, title, updated_at, data_type, data))?; + + Ok(()) + } + + pub fn list_threads(&self) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = + connection.select_bound::<(), (Arc, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; + let mut threads = Vec::new(); + + for (id, summary, updated_at) in rows { + threads.push(DbThreadMetadata { + id: acp::SessionId(id), + title: summary.into(), + updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), + }); + } + + Ok(threads) + }) + } + + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + let mut select = connection.select_bound::, (DataType, Vec)>(indoc! {" + SELECT data_type, data FROM threads WHERE id = ? LIMIT 1 + "})?; + + let rows = select(id.0)?; + if let Some((data_type, data)) = rows.into_iter().next() { + let json_data = match data_type { + DataType::Zstd => { + let decompressed = zstd::decode_all(&data[..])?; + String::from_utf8(decompressed)? + } + DataType::Json => String::from_utf8(data)?, + }; + + let thread = DbThread::from_json(json_data.as_bytes())?; + Ok(Some(thread)) + } else { + Ok(None) + } + }) + } + + pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> { + let connection = self.connection.clone(); + + self.executor + .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) + } + + pub fn delete_thread(&self, id: acp::SessionId) -> Task> { + let connection = self.connection.clone(); + + self.executor.spawn(async move { + let connection = connection.lock(); + + let mut delete = connection.exec_bound::>(indoc! {" + DELETE FROM threads WHERE id = ? + "})?; + + delete(id.0)?; + + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent::MessageSegment; + use agent::context::LoadedContext; + use client::Client; + use fs::FakeFs; + use gpui::AppContext; + use gpui::TestAppContext; + use http_client::FakeHttpClient; + use language_model::Role; + use pretty_assertions::assert_matches; + use project::Project; + use settings::SettingsStore; + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + ThreadsDatabase::init(cx); + + let http_client = FakeHttpClient::with_404_response(); + let clock = Arc::new(clock::FakeSystemClock::new()); + let client = Client::new(clock, http_client, cx); + agent::init(cx); + agent_settings::init(cx); + language_model::init(client.clone(), cx); + }); + } + + #[gpui::test] + async fn test_retrieving_old_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + + // Save a thread using the old agent. + { + let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx)); + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + thread.update(cx, |thread, cx| { + thread.insert_message( + Role::User, + vec![MessageSegment::Text("Hey!".into())], + LoadedContext::default(), + vec![], + false, + cx, + ); + thread.insert_message( + Role::Assistant, + vec![MessageSegment::Text("How're you doing?".into())], + LoadedContext::default(), + vec![], + false, + cx, + ) + }); + thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) + .await + .unwrap(); + } + + let db = cx + .update(|cx| ThreadsDatabase::global_future(cx)) + .await + .unwrap(); + let threads = db.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + let thread = db + .load_thread(threads[0].id.clone()) + .await + .unwrap() + .unwrap(); + assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n"); + assert_eq!( + thread.messages[1].to_markdown(), + "## Assistant\n\nHow're you doing?\n" + ); + } +}