Test round-tripping old threads
This commit is contained in:
parent
6b6b7e66e1
commit
fd8ea2acfc
4 changed files with 318 additions and 5 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -222,6 +222,7 @@ dependencies = [
|
|||
"log",
|
||||
"lsp",
|
||||
"open",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
"portable-pty",
|
||||
"pretty_assertions",
|
||||
|
@ -234,6 +235,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"sqlez",
|
||||
"task",
|
||||
"tempfile",
|
||||
"terminal",
|
||||
|
@ -250,6 +252,7 @@ dependencies = [
|
|||
"workspace-hack",
|
||||
"worktree",
|
||||
"zlog",
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
@ -893,7 +893,7 @@ impl ThreadsDatabase {
|
|||
|
||||
let needs_migration_from_heed = mdb_path.exists();
|
||||
|
||||
let connection = if *ZED_STATELESS {
|
||||
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
|
||||
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||
} else {
|
||||
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||
|
|
|
@ -38,6 +38,7 @@ language_model.workspace = true
|
|||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
open.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
portable-pty.workspace = true
|
||||
project.workspace = true
|
||||
|
@ -47,6 +48,7 @@ schemars.workspace = true
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sqlez.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
terminal.workspace = true
|
||||
|
@ -58,8 +60,10 @@ watch.workspace = true
|
|||
web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent = { workspace = true, "features" = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,17 +1,34 @@
|
|||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||
use agent::thread_store;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use futures::{FutureExt, future::Shared};
|
||||
use gpui::{BackgroundExecutor, Global, ReadGlobal, Task};
|
||||
use indoc::indoc;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::SharedString;
|
||||
use sqlez::{
|
||||
bindable::{Bind, Column},
|
||||
connection::Connection,
|
||||
statement::Statement,
|
||||
};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::{App, SharedString};
|
||||
|
||||
pub type DbMessage = crate::Message;
|
||||
pub type DbSummary = agent::thread::DetailedSummaryState;
|
||||
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
||||
pub type DbThreadMetadata = thread_store::SerializedThreadMetadata;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbThreadMetadata {
|
||||
pub id: acp::SessionId,
|
||||
#[serde(alias = "summary")]
|
||||
pub title: SharedString,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DbThread {
|
||||
|
@ -165,3 +182,292 @@ impl DbThread {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
|
||||
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DataType {
|
||||
#[serde(rename = "json")]
|
||||
Json,
|
||||
#[serde(rename = "zstd")]
|
||||
Zstd,
|
||||
}
|
||||
|
||||
impl Bind for DataType {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let value = match self {
|
||||
DataType::Json => "json",
|
||||
DataType::Zstd => "zstd",
|
||||
};
|
||||
value.bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for DataType {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (value, next_index) = String::column(statement, start_index)?;
|
||||
let data_type = match value.as_str() {
|
||||
"json" => DataType::Json,
|
||||
"zstd" => DataType::Zstd,
|
||||
_ => anyhow::bail!("Unknown data type: {}", value),
|
||||
};
|
||||
Ok((data_type, next_index))
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
|
||||
|
||||
impl Global for GlobalThreadsDatabase {}
|
||||
|
||||
pub(crate) struct ThreadsDatabase {
|
||||
executor: BackgroundExecutor,
|
||||
connection: Arc<Mutex<Connection>>,
|
||||
}
|
||||
|
||||
impl ThreadsDatabase {
|
||||
fn connection(&self) -> Arc<Mutex<Connection>> {
|
||||
self.connection.clone()
|
||||
}
|
||||
|
||||
const COMPRESSION_LEVEL: i32 = 3;
|
||||
}
|
||||
|
||||
impl ThreadsDatabase {
|
||||
fn global_future(
|
||||
cx: &mut App,
|
||||
) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
|
||||
GlobalThreadsDatabase::global(cx).0.clone()
|
||||
}
|
||||
|
||||
fn init(cx: &mut App) {
|
||||
let executor = cx.background_executor().clone();
|
||||
let database_future = executor
|
||||
.spawn({
|
||||
let executor = executor.clone();
|
||||
let threads_dir = paths::data_dir().join("threads");
|
||||
async move {
|
||||
match ThreadsDatabase::new(threads_dir, executor) {
|
||||
Ok(db) => Ok(Arc::new(db)),
|
||||
Err(err) => Err(Arc::new(err)),
|
||||
}
|
||||
}
|
||||
})
|
||||
.shared();
|
||||
|
||||
cx.set_global(GlobalThreadsDatabase(database_future));
|
||||
}
|
||||
|
||||
pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
|
||||
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
|
||||
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||
} else {
|
||||
std::fs::create_dir_all(&threads_dir)?;
|
||||
let sqlite_path = threads_dir.join("threads.db");
|
||||
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||
};
|
||||
|
||||
connection.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
summary TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
data_type TEXT NOT NULL,
|
||||
data BLOB NOT NULL
|
||||
)
|
||||
"})?()
|
||||
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
|
||||
|
||||
let db = Self {
|
||||
executor: executor.clone(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
};
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn save_thread_sync(
|
||||
connection: &Arc<Mutex<Connection>>,
|
||||
id: acp::SessionId,
|
||||
thread: DbThread,
|
||||
) -> Result<()> {
|
||||
let json_data = serde_json::to_string(&thread)?;
|
||||
let title = thread.title.to_string();
|
||||
let updated_at = thread.updated_at.to_rfc3339();
|
||||
|
||||
let connection = connection.lock();
|
||||
|
||||
let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
|
||||
let data_type = DataType::Zstd;
|
||||
let data = compressed;
|
||||
|
||||
let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
|
||||
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
|
||||
"})?;
|
||||
|
||||
insert((id.0, title, updated_at, data_type, data))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
let mut select =
|
||||
connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
|
||||
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
|
||||
"})?;
|
||||
|
||||
let rows = select(())?;
|
||||
let mut threads = Vec::new();
|
||||
|
||||
for (id, summary, updated_at) in rows {
|
||||
threads.push(DbThreadMetadata {
|
||||
id: acp::SessionId(id),
|
||||
title: summary.into(),
|
||||
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(threads)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
|
||||
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
||||
"})?;
|
||||
|
||||
let rows = select(id.0)?;
|
||||
if let Some((data_type, data)) = rows.into_iter().next() {
|
||||
let json_data = match data_type {
|
||||
DataType::Zstd => {
|
||||
let decompressed = zstd::decode_all(&data[..])?;
|
||||
String::from_utf8(decompressed)?
|
||||
}
|
||||
DataType::Json => String::from_utf8(data)?,
|
||||
};
|
||||
|
||||
let thread = DbThread::from_json(json_data.as_bytes())?;
|
||||
Ok(Some(thread))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
|
||||
}
|
||||
|
||||
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
|
||||
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
|
||||
DELETE FROM threads WHERE id = ?
|
||||
"})?;
|
||||
|
||||
delete(id.0)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use agent::MessageSegment;
|
||||
use agent::context::LoadedContext;
|
||||
use client::Client;
|
||||
use fs::FakeFs;
|
||||
use gpui::AppContext;
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use language_model::Role;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
ThreadsDatabase::init(cx);
|
||||
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||
let client = Client::new(clock, http_client, cx);
|
||||
agent::init(cx);
|
||||
agent_settings::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
|
||||
// Save a thread using the old agent.
|
||||
{
|
||||
let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_message(
|
||||
Role::User,
|
||||
vec![MessageSegment::Text("Hey!".into())],
|
||||
LoadedContext::default(),
|
||||
vec![],
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text("How're you doing?".into())],
|
||||
LoadedContext::default(),
|
||||
vec![],
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
thread_store
|
||||
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let db = cx
|
||||
.update(|cx| ThreadsDatabase::global_future(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let threads = db.list_threads().await.unwrap();
|
||||
assert_eq!(threads.len(), 1);
|
||||
let thread = db
|
||||
.load_thread(threads[0].id.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
|
||||
assert_eq!(
|
||||
thread.messages[1].to_markdown(),
|
||||
"## Assistant\n\nHow're you doing?\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue