diff --git a/Cargo.lock b/Cargo.lock
index bcbff9efc3..12cc4f2928 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -114,6 +114,7 @@ dependencies = [
  "serde_json_lenient",
  "settings",
  "smol",
+ "sqlez",
  "streaming_diff",
  "telemetry",
  "telemetry_events",
@@ -133,6 +134,7 @@ dependencies = [
  "workspace-hack",
  "zed_actions",
  "zed_llm_client",
+ "zstd",
 ]
 
 [[package]]
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index f9c6fcd4e4..c1f9d9a3fa 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -46,6 +46,7 @@ git.workspace = true
 gpui.workspace = true
 heed.workspace = true
 html_to_markdown.workspace = true
+indoc.workspace = true
 http_client.workspace = true
 indexed_docs.workspace = true
 inventory.workspace = true
@@ -78,6 +79,7 @@ serde_json.workspace = true
 serde_json_lenient.workspace = true
 settings.workspace = true
 smol.workspace = true
+sqlez.workspace = true
 streaming_diff.workspace = true
 telemetry.workspace = true
 telemetry_events.workspace = true
@@ -97,6 +99,7 @@ workspace-hack.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
 zed_llm_client.workspace = true
+zstd.workspace = true
 
 [dev-dependencies]
 buffer_diff = { workspace = true, features = ["test-support"] }
diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs
index 8c6fc909e9..b6edbc3919 100644
--- a/crates/agent/src/thread_store.rs
+++ b/crates/agent/src/thread_store.rs
@@ -1,8 +1,7 @@
-use std::borrow::Cow;
 use std::cell::{Ref, RefCell};
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
 use anyhow::{Context as _, Result, anyhow};
@@ -17,8 +16,7 @@ use gpui::{
     App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
     Subscription, Task, prelude::*,
 };
-use heed::Database;
-use heed::types::SerdeBincode;
+
 use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
 use project::context_server_store::{ContextServerStatus, ContextServerStore};
 use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -35,6 +33,42 @@ use crate::context_server_tool::ContextServerTool;
 use crate::thread::{
     DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 };
+use indoc::indoc;
+use sqlez::{
+    bindable::{Bind, Column},
+    connection::Connection,
+    statement::Statement,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DataType {
+    #[serde(rename = "json")]
+    Json,
+    #[serde(rename = "zstd")]
+    Zstd,
+}
+
+impl Bind for DataType {
+    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+        let value = match self {
+            DataType::Json => "json",
+            DataType::Zstd => "zstd",
+        };
+        value.bind(statement, start_index)
+    }
+}
+
+impl Column for DataType {
+    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+        let (value, next_index) = String::column(statement, start_index)?;
+        let data_type = match value.as_str() {
+            "json" => DataType::Json,
+            "zstd" => DataType::Zstd,
+            _ => anyhow::bail!("Unknown data type: {}", value),
+        };
+        Ok((data_type, next_index))
+    }
+}
 
 const RULES_FILE_NAMES: [&'static str; 6] = [
     ".rules",
@@ -866,25 +900,27 @@ impl Global for GlobalThreadsDatabase {}
 
 pub(crate) struct ThreadsDatabase {
     executor: BackgroundExecutor,
-    env: heed::Env,
-    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
+    connection: Arc<Mutex<Connection>>,
 }
 
-impl heed::BytesEncode<'_> for SerializedThread {
-    type EItem = SerializedThread;
+impl ThreadsDatabase {
+    fn connection(&self) -> Arc<Mutex<Connection>> {
+        self.connection.clone()
+    }
 
-    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
-        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
+    const COMPRESSION_LEVEL: i32 = 3;
+}
+
+impl Bind for ThreadId {
+    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+        self.to_string().bind(statement, start_index)
     }
 }
 
-impl<'a> heed::BytesDecode<'a> for SerializedThread {
-    type DItem = SerializedThread;
-
-    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
-        // We implement this type manually because we want to call `SerializedThread::from_json`,
-        // instead of the Deserialize trait implementation for `SerializedThread`.
-        SerializedThread::from_json(bytes).map_err(Into::into)
+impl Column for ThreadId {
+    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+        let (id_str, next_index) = String::column(statement, start_index)?;
+        Ok((ThreadId::from(id_str.as_str()), next_index))
     }
 }
@@ -900,8 +936,8 @@ impl ThreadsDatabase {
         let database_future = executor
             .spawn({
                 let executor = executor.clone();
-                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
-                async move { ThreadsDatabase::new(database_path, executor) }
+                let threads_dir = paths::data_dir().join("threads");
+                async move { ThreadsDatabase::new(threads_dir, executor) }
             })
             .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
             .boxed()
@@ -910,41 +946,144 @@ impl ThreadsDatabase {
         cx.set_global(GlobalThreadsDatabase(database_future));
     }
 
-    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
-        std::fs::create_dir_all(&path)?;
+    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
+        std::fs::create_dir_all(&threads_dir)?;
+
+        let sqlite_path = threads_dir.join("threads.db");
+        let mdb_path = threads_dir.join("threads-db.1.mdb");
+
+        let needs_migration_from_heed = mdb_path.exists();
+
+        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
+
+        connection.exec(indoc! {"
{" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor: executor.clone(), + connection: Arc::new(Mutex::new(connection)), + }; + + if needs_migration_from_heed { + let db_connection = db.connection(); + let executor_clone = executor.clone(); + executor + .spawn(async move { + log::info!("Starting threads.db migration"); + Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?; + std::fs::remove_dir_all(mdb_path)?; + log::info!("threads.db migrated to sqlite"); + Ok::<(), anyhow::Error>(()) + }) + .detach(); + } + + Ok(db) + } + + // Remove this migration after 2025-09-01 + fn migrate_from_heed( + mdb_path: &Path, + connection: Arc>, + _executor: BackgroundExecutor, + ) -> Result<()> { + use heed::types::SerdeBincode; + struct SerializedThreadHeed(SerializedThread); + + impl heed::BytesEncode<'_> for SerializedThreadHeed { + type EItem = SerializedThreadHeed; + + fn bytes_encode( + item: &Self::EItem, + ) -> Result, heed::BoxedError> { + serde_json::to_vec(&item.0) + .map(std::borrow::Cow::Owned) + .map_err(Into::into) + } + } + + impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed { + type DItem = SerializedThreadHeed; + + fn bytes_decode(bytes: &'a [u8]) -> Result { + SerializedThread::from_json(bytes) + .map(SerializedThreadHeed) + .map_err(Into::into) + } + } const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; + let env = unsafe { heed::EnvOpenOptions::new() .map_size(ONE_GB_IN_BYTES) .max_dbs(1) - .open(path)? + .open(mdb_path)? }; - let mut txn = env.write_txn()?; - let threads = env.create_database(&mut txn, Some("threads"))?; - txn.commit()?; + let txn = env.write_txn()?; + let threads: heed::Database, SerializedThreadHeed> = env + .open_database(&txn, Some("threads"))? + .ok_or_else(|| anyhow!("threads database not found"))?; - Ok(Self { - executor, - env, - threads, - }) + for result in threads.iter(&txn)? { + let (thread_id, thread_heed) = result?; + Self::save_thread_sync(&connection, thread_id, thread_heed.0)?; + } + + Ok(()) + } + + fn save_thread_sync( + connection: &Arc>, + id: ThreadId, + thread: SerializedThread, + ) -> Result<()> { + let json_data = serde_json::to_string(&thread)?; + let summary = thread.summary.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + + let connection = connection.lock().unwrap(); + + let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + "})?; + + insert((id, summary, updated_at, data_type, data))?; + + Ok(()) } pub fn list_threads(&self) -> Task>> { - let env = self.env.clone(); - let threads = self.threads; + let connection = self.connection.clone(); self.executor.spawn(async move { - let txn = env.read_txn()?; - let mut iter = threads.iter(&txn)?; + let connection = connection.lock().unwrap(); + let mut select = + connection.select_bound::<(), (ThreadId, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; let mut threads = Vec::new(); - while let Some((key, value)) = iter.next().transpose()? 
+
+            for (id, summary, updated_at) in rows {
                 threads.push(SerializedThreadMetadata {
-                    id: key,
-                    summary: value.summary,
-                    updated_at: value.updated_at,
+                    id,
+                    summary: summary.into(),
+                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                 });
             }
@@ -953,36 +1092,51 @@ impl ThreadsDatabase {
     }
 
     pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
         self.executor.spawn(async move {
-            let txn = env.read_txn()?;
-            let thread = threads.get(&txn, &id)?;
-            Ok(thread)
+            let connection = connection.lock().unwrap();
+            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
+                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
+            "})?;
+
+            let rows = select(id)?;
+            if let Some((data_type, data)) = rows.into_iter().next() {
+                let json_data = match data_type {
+                    DataType::Zstd => {
+                        let decompressed = zstd::decode_all(&data[..])?;
+                        String::from_utf8(decompressed)?
+                    }
+                    DataType::Json => String::from_utf8(data)?,
+                };
+
+                let thread = SerializedThread::from_json(json_data.as_bytes())?;
+                Ok(Some(thread))
+            } else {
+                Ok(None)
+            }
         })
     }
 
     pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
-        self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.put(&mut txn, &id, &thread)?;
-            txn.commit()?;
-            Ok(())
-        })
+        self.executor
+            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
     }
 
     pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
         self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.delete(&mut txn, &id)?;
-            txn.commit()?;
+            let connection = connection.lock().unwrap();
+
+            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
+                DELETE FROM threads WHERE id = ?
+            "})?;
+
+            delete(id)?;
+
             Ok(())
        })
    }
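
Note on the storage format: `save_thread_sync` serializes the thread to JSON, compresses it with zstd at level 3 (`COMPRESSION_LEVEL`), and tags the row with `data_type = 'zstd'`; `try_find_thread` reverses this, falling back to raw JSON for `'json'` rows. A minimal standalone sketch of that round-trip, using only the `zstd` crate calls that appear in the diff (the payload string is hypothetical):

```rust
// Round-trip sketch for the blob format introduced above: JSON in, zstd out, and back.
// Level 3 mirrors ThreadsDatabase::COMPRESSION_LEVEL.
fn main() -> anyhow::Result<()> {
    let json_data = r#"{"summary":"example thread","messages":[]}"#; // hypothetical payload

    // What save_thread_sync does before the INSERT:
    let compressed: Vec<u8> = zstd::encode_all(json_data.as_bytes(), 3)?;

    // What try_find_thread does when it reads a DataType::Zstd row:
    let decompressed = zstd::decode_all(&compressed[..])?;
    let restored = String::from_utf8(decompressed)?;

    assert_eq!(json_data, restored);
    Ok(())
}
```

Keeping a `data_type` column rather than assuming every row is compressed means old rows (or a future format) remain readable without a schema migration.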
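The sqlez API used throughout follows a prepare-then-call shape: `exec`, `exec_bound::<Params>`, and `select_bound::<Params, Row>` each return a closure, and parameters and rows are converted through `Bind`/`Column` impls (as added for `DataType` and `ThreadId` above). A hedged sketch of that pattern against the new table; `Connection::open_memory` is an assumed in-memory constructor for illustration, while every other call mirrors the diff:

```rust
use indoc::indoc;
use sqlez::connection::Connection;

fn demo() -> anyhow::Result<()> {
    // Assumed in-memory constructor, for illustration only.
    let connection = Connection::open_memory(Some("threads-demo"));

    // `exec` prepares the statement; invoking the returned closure runs it.
    connection.exec(indoc! {"
        CREATE TABLE IF NOT EXISTS threads (
            id TEXT PRIMARY KEY,
            summary TEXT NOT NULL,
            updated_at TEXT NOT NULL,
            data_type TEXT NOT NULL,
            data BLOB NOT NULL
        )
    "})?()?;

    // `exec_bound` binds a typed tuple to the `?` placeholders via `Bind`.
    let mut insert = connection.exec_bound::<(String, String, String, String, Vec<u8>)>(indoc! {"
        INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data)
        VALUES (?, ?, ?, ?, ?)
    "})?;
    insert((
        "thread-1".to_string(),
        "Example".to_string(),
        "2025-01-01T00:00:00+00:00".to_string(),
        "json".to_string(),
        br#"{"summary":"Example"}"#.to_vec(),
    ))?;

    // `select_bound` decodes each result row through `Column`.
    let mut select = connection.select_bound::<String, (String, Vec<u8>)>(indoc! {"
        SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
    "})?;
    let rows = select("thread-1".to_string())?;
    assert_eq!(rows.len(), 1);
    Ok(())
}
```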