From c874f1fa9d70c889602e54a8f6ac3ad8c2a2734a Mon Sep 17 00:00:00 2001
From: Oleksiy Syvokon
Date: Mon, 2 Jun 2025 17:01:34 +0300
Subject: [PATCH] agent: Migrate thread storage to SQLite with zstd compression (#31741)

Previously, threads were stored in LMDB, which consumed excessive disk
space and was capped at 1GB. This change migrates thread storage to an
SQLite database, with each thread's JSON object compressed using zstd.

I considered training a custom zstd dictionary and storing it in a
separate table, but the additional complexity outweighed the modest
space savings (up to 20%). I ended up using zstd's defaults, so each
row is a self-contained frame with no external dictionary.

Threads can be exported relatively easily from outside the application:

```
$ sqlite3 threads.db "SELECT hex(data) FROM threads LIMIT 5;" | xxd -r -p | zstd -d | fx
```

Benchmarks:

- Original heed database: 200MB
- SQLite, uncompressed: 51MB
- SQLite, compressed (this PR): 4.0MB
- SQLite, compressed with a trained dictionary: 3.8MB

Release Notes:

- Migrated thread storage to SQLite with compression
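
The export above can also be scripted. Here is a minimal sketch of an
external reader, not part of this change, assuming the `rusqlite` and
`zstd` crates (inside the app, access goes through `sqlez`):

```
use rusqlite::Connection;

fn main() -> anyhow::Result<()> {
    let db = Connection::open("threads.db")?;
    let mut stmt = db.prepare("SELECT id, data_type, data FROM threads")?;
    let rows = stmt.query_map([], |row| {
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, Vec<u8>>(2)?,
        ))
    })?;
    for row in rows {
        let (id, data_type, data) = row?;
        // Honor the data_type tag: each row is raw JSON or a zstd frame.
        let json = match data_type.as_str() {
            "zstd" => String::from_utf8(zstd::decode_all(&data[..])?)?,
            _ => String::from_utf8(data)?,
        };
        println!("{id}\t{json}");
    }
    Ok(())
}
```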
---
 Cargo.lock                       |   2 +
 crates/agent/Cargo.toml          |   3 +
 crates/agent/src/thread_store.rs | 266 ++++++++++++++++++++++++-------
 3 files changed, 215 insertions(+), 56 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index bcbff9efc3..12cc4f2928 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -114,6 +114,7 @@ dependencies = [
  "serde_json_lenient",
  "settings",
  "smol",
+ "sqlez",
  "streaming_diff",
  "telemetry",
  "telemetry_events",
@@ -133,6 +134,7 @@ dependencies = [
  "workspace-hack",
  "zed_actions",
  "zed_llm_client",
+ "zstd",
 ]
 
 [[package]]
diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml
index f9c6fcd4e4..c1f9d9a3fa 100644
--- a/crates/agent/Cargo.toml
+++ b/crates/agent/Cargo.toml
@@ -46,6 +46,7 @@ git.workspace = true
 gpui.workspace = true
 heed.workspace = true
 html_to_markdown.workspace = true
+indoc.workspace = true
 http_client.workspace = true
 indexed_docs.workspace = true
 inventory.workspace = true
@@ -78,6 +79,7 @@ serde_json.workspace = true
 serde_json_lenient.workspace = true
 settings.workspace = true
 smol.workspace = true
+sqlez.workspace = true
 streaming_diff.workspace = true
 telemetry.workspace = true
 telemetry_events.workspace = true
@@ -97,6 +99,7 @@ workspace-hack.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
 zed_llm_client.workspace = true
+zstd.workspace = true
 
 [dev-dependencies]
 buffer_diff = { workspace = true, features = ["test-support"] }
diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs
index 8c6fc909e9..b6edbc3919 100644
--- a/crates/agent/src/thread_store.rs
+++ b/crates/agent/src/thread_store.rs
@@ -1,8 +1,7 @@
-use std::borrow::Cow;
 use std::cell::{Ref, RefCell};
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
 use anyhow::{Context as _, Result, anyhow};
@@ -17,8 +16,7 @@ use gpui::{
     App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
     Subscription, Task, prelude::*,
 };
-use heed::Database;
-use heed::types::SerdeBincode;
+
 use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
 use project::context_server_store::{ContextServerStatus, ContextServerStore};
 use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -35,6 +33,42 @@ use crate::context_server_tool::ContextServerTool;
 use crate::thread::{
     DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 };
+use indoc::indoc;
+use sqlez::{
+    bindable::{Bind, Column},
+    connection::Connection,
+    statement::Statement,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DataType {
+    #[serde(rename = "json")]
+    Json,
+    #[serde(rename = "zstd")]
+    Zstd,
+}
+
+impl Bind for DataType {
+    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+        let value = match self {
+            DataType::Json => "json",
+            DataType::Zstd => "zstd",
+        };
+        value.bind(statement, start_index)
+    }
+}
+
+impl Column for DataType {
+    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+        let (value, next_index) = String::column(statement, start_index)?;
+        let data_type = match value.as_str() {
+            "json" => DataType::Json,
+            "zstd" => DataType::Zstd,
+            _ => anyhow::bail!("Unknown data type: {}", value),
+        };
+        Ok((data_type, next_index))
+    }
+}
 
 const RULES_FILE_NAMES: [&'static str; 6] = [
     ".rules",
@@ -866,25 +900,27 @@ impl Global for GlobalThreadsDatabase {}
 
 pub(crate) struct ThreadsDatabase {
     executor: BackgroundExecutor,
-    env: heed::Env,
-    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
+    connection: Arc<Mutex<Connection>>,
 }
 
-impl heed::BytesEncode<'_> for SerializedThread {
-    type EItem = SerializedThread;
+impl ThreadsDatabase {
+    fn connection(&self) -> Arc<Mutex<Connection>> {
+        self.connection.clone()
+    }
 
-    fn bytes_encode(item: &Self::EItem) -> Result<Cow<'_, [u8]>, heed::BoxedError> {
-        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
+    const COMPRESSION_LEVEL: i32 = 3;
+}
+
+impl Bind for ThreadId {
+    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+        self.to_string().bind(statement, start_index)
     }
 }
 
-impl<'a> heed::BytesDecode<'a> for SerializedThread {
-    type DItem = SerializedThread;
-
-    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
-        // We implement this type manually because we want to call `SerializedThread::from_json`,
-        // instead of the Deserialize trait implementation for `SerializedThread`.
-        SerializedThread::from_json(bytes).map_err(Into::into)
+impl Column for ThreadId {
+    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+        let (id_str, next_index) = String::column(statement, start_index)?;
+        Ok((ThreadId::from(id_str.as_str()), next_index))
     }
 }
 
@@ -900,8 +936,8 @@ impl ThreadsDatabase {
         let database_future = executor
             .spawn({
                 let executor = executor.clone();
-                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
-                async move { ThreadsDatabase::new(database_path, executor) }
+                let threads_dir = paths::data_dir().join("threads");
+                async move { ThreadsDatabase::new(threads_dir, executor) }
             })
             .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
             .boxed()
@@ -910,41 +946,144 @@ impl ThreadsDatabase {
         cx.set_global(GlobalThreadsDatabase(database_future));
     }
 
-    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
-        std::fs::create_dir_all(&path)?;
+    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
+        std::fs::create_dir_all(&threads_dir)?;
+
+        let sqlite_path = threads_dir.join("threads.db");
+        let mdb_path = threads_dir.join("threads-db.1.mdb");
+
+        let needs_migration_from_heed = mdb_path.exists();
+
+        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
+
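+        // `data_type` records how `data` is encoded ("json" or "zstd") so
+        // that rows written by either codec remain readable (see DataType).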
{" + CREATE TABLE IF NOT EXISTS threads ( + id TEXT PRIMARY KEY, + summary TEXT NOT NULL, + updated_at TEXT NOT NULL, + data_type TEXT NOT NULL, + data BLOB NOT NULL + ) + "})?() + .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; + + let db = Self { + executor: executor.clone(), + connection: Arc::new(Mutex::new(connection)), + }; + + if needs_migration_from_heed { + let db_connection = db.connection(); + let executor_clone = executor.clone(); + executor + .spawn(async move { + log::info!("Starting threads.db migration"); + Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?; + std::fs::remove_dir_all(mdb_path)?; + log::info!("threads.db migrated to sqlite"); + Ok::<(), anyhow::Error>(()) + }) + .detach(); + } + + Ok(db) + } + + // Remove this migration after 2025-09-01 + fn migrate_from_heed( + mdb_path: &Path, + connection: Arc>, + _executor: BackgroundExecutor, + ) -> Result<()> { + use heed::types::SerdeBincode; + struct SerializedThreadHeed(SerializedThread); + + impl heed::BytesEncode<'_> for SerializedThreadHeed { + type EItem = SerializedThreadHeed; + + fn bytes_encode( + item: &Self::EItem, + ) -> Result, heed::BoxedError> { + serde_json::to_vec(&item.0) + .map(std::borrow::Cow::Owned) + .map_err(Into::into) + } + } + + impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed { + type DItem = SerializedThreadHeed; + + fn bytes_decode(bytes: &'a [u8]) -> Result { + SerializedThread::from_json(bytes) + .map(SerializedThreadHeed) + .map_err(Into::into) + } + } const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; + let env = unsafe { heed::EnvOpenOptions::new() .map_size(ONE_GB_IN_BYTES) .max_dbs(1) - .open(path)? + .open(mdb_path)? }; - let mut txn = env.write_txn()?; - let threads = env.create_database(&mut txn, Some("threads"))?; - txn.commit()?; + let txn = env.write_txn()?; + let threads: heed::Database, SerializedThreadHeed> = env + .open_database(&txn, Some("threads"))? + .ok_or_else(|| anyhow!("threads database not found"))?; - Ok(Self { - executor, - env, - threads, - }) + for result in threads.iter(&txn)? { + let (thread_id, thread_heed) = result?; + Self::save_thread_sync(&connection, thread_id, thread_heed.0)?; + } + + Ok(()) + } + + fn save_thread_sync( + connection: &Arc>, + id: ThreadId, + thread: SerializedThread, + ) -> Result<()> { + let json_data = serde_json::to_string(&thread)?; + let summary = thread.summary.to_string(); + let updated_at = thread.updated_at.to_rfc3339(); + + let connection = connection.lock().unwrap(); + + let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; + let data_type = DataType::Zstd; + let data = compressed; + + let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) + "})?; + + insert((id, summary, updated_at, data_type, data))?; + + Ok(()) } pub fn list_threads(&self) -> Task>> { - let env = self.env.clone(); - let threads = self.threads; + let connection = self.connection.clone(); self.executor.spawn(async move { - let txn = env.read_txn()?; - let mut iter = threads.iter(&txn)?; + let connection = connection.lock().unwrap(); + let mut select = + connection.select_bound::<(), (ThreadId, String, String)>(indoc! {" + SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC + "})?; + + let rows = select(())?; let mut threads = Vec::new(); - while let Some((key, value)) = iter.next().transpose()? 
+    fn save_thread_sync(
+        connection: &Arc<Mutex<Connection>>,
+        id: ThreadId,
+        thread: SerializedThread,
+    ) -> Result<()> {
+        let json_data = serde_json::to_string(&thread)?;
+        let summary = thread.summary.to_string();
+        let updated_at = thread.updated_at.to_rfc3339();
+
+        let connection = connection.lock().unwrap();
+
+        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
+        let data_type = DataType::Zstd;
+        let data = compressed;
+
+        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
+            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
+        "})?;
+
+        insert((id, summary, updated_at, data_type, data))?;
+
+        Ok(())
+    }
 
     pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
         self.executor.spawn(async move {
-            let txn = env.read_txn()?;
-            let mut iter = threads.iter(&txn)?;
+            let connection = connection.lock().unwrap();
+            let mut select =
+                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
+                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
+                "})?;
+
+            let rows = select(())?;
             let mut threads = Vec::new();
-            while let Some((key, value)) = iter.next().transpose()? {
+
+            for (id, summary, updated_at) in rows {
                 threads.push(SerializedThreadMetadata {
-                    id: key,
-                    summary: value.summary,
-                    updated_at: value.updated_at,
+                    id,
+                    summary: summary.into(),
+                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                 });
             }
 
@@ -953,36 +1092,51 @@ impl ThreadsDatabase {
     }
 
     pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
         self.executor.spawn(async move {
-            let txn = env.read_txn()?;
-            let thread = threads.get(&txn, &id)?;
-            Ok(thread)
+            let connection = connection.lock().unwrap();
+            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
+                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
+            "})?;
+
+            let rows = select(id)?;
+            if let Some((data_type, data)) = rows.into_iter().next() {
+                let json_data = match data_type {
+                    DataType::Zstd => {
+                        let decompressed = zstd::decode_all(&data[..])?;
+                        String::from_utf8(decompressed)?
+                    }
+                    DataType::Json => String::from_utf8(data)?,
+                };
+
+                let thread = SerializedThread::from_json(json_data.as_bytes())?;
+                Ok(Some(thread))
+            } else {
+                Ok(None)
+            }
         })
     }
 
     pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
-        self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.put(&mut txn, &id, &thread)?;
-            txn.commit()?;
-            Ok(())
-        })
+        self.executor
+            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
     }
 
     pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
-        let env = self.env.clone();
-        let threads = self.threads;
+        let connection = self.connection.clone();
 
         self.executor.spawn(async move {
-            let mut txn = env.write_txn()?;
-            threads.delete(&mut txn, &id)?;
-            txn.commit()?;
+            let connection = connection.lock().unwrap();
+
+            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
+                DELETE FROM threads WHERE id = ?
+            "})?;
+
+            delete(id)?;
+
             Ok(())
         })
    }
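
As an appendix to the benchmarks in the commit message: the "trained
dictionary" variant could be reproduced with something like the sketch
below. This is illustrative only, assuming the `zstd` crate's `dict` and
`bulk` modules and synthetic stand-ins for real thread JSON; the patch
deliberately ships without a dictionary.

```
fn main() -> std::io::Result<()> {
    // Hypothetical stand-ins for serialized threads; a real experiment
    // would sample rows from the threads table instead.
    let samples: Vec<Vec<u8>> = (0..1024)
        .map(|i| {
            format!(
                r#"{{"summary":"thread {i}","updated_at":"2025-06-02T17:01:34Z","messages":[{{"role":"user","text":"message {i}"}}]}}"#
            )
            .into_bytes()
        })
        .collect();

    // Train a small shared dictionary over the corpus...
    let dict = zstd::dict::from_samples(&samples, 4 * 1024)?;

    // ...compress one thread against it at the level the store uses...
    let mut compressor = zstd::bulk::Compressor::with_dictionary(3, &dict)?;
    let compressed = compressor.compress(&samples[0])?;

    // ...and decompress, which requires the dictionary to still be around.
    // That extra moving part is what the patch trades away for ~5% more disk.
    let mut decompressor = zstd::bulk::Decompressor::with_dictionary(&dict)?;
    let restored = decompressor.decompress(&compressed, samples[0].len())?;
    assert_eq!(restored, samples[0]);

    println!("{} -> {} bytes", samples[0].len(), compressed.len());
    Ok(())
}
```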