From afab4b522e5b4ee9d9b48a4cb473a2b060530eba Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 9 Jun 2025 14:20:19 +0200 Subject: [PATCH] agent: Add tests for thread serialization code (#32383) This adds some unit tests to ensure that the `update(...)`/migration path to the latest versions works correctly Release Notes: - N/A --- Cargo.lock | 1 + crates/agent/Cargo.toml | 1 + crates/agent/src/thread.rs | 10 +- crates/agent/src/thread_store.rs | 198 +++++++++++++++++++++++++++++-- 4 files changed, 197 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index abf5705a78..c9f91959dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,7 @@ dependencies = [ "paths", "picker", "postage", + "pretty_assertions", "project", "prompt_store", "proto", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index cf0badcff6..66e4a5c78f 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } +pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } rand.workspace = true diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index bb8cc706bb..ad0e9260dc 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -195,20 +195,20 @@ impl MessageSegment { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProjectSnapshot { pub worktree_snapshots: Vec, pub unsaved_buffer_paths: Vec, pub timestamp: DateTime, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct WorktreeSnapshot { pub worktree_path: String, pub git_state: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct GitState { pub remote_url: Option, pub head_sha: Option, @@ -247,7 +247,7 @@ impl LastRestoreCheckpoint { } } -#[derive(Clone, Debug, Default, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub enum DetailedSummaryState { #[default] NotGenerated, @@ -391,7 +391,7 @@ impl ThreadSummary { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ExceededWindowError { /// Model used when last message exceeded context window model_id: LanguageModelId, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index a86fcda072..620279249e 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata { pub updated_at: DateTime, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct SerializedThread { pub version: String, pub summary: SharedString, @@ -629,7 +629,7 @@ pub struct SerializedThread { pub profile: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct SerializedLanguageModel { pub provider: String, pub model: String, @@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 { messages.push(message); } - SerializedThread { messages, ..self.0 } + SerializedThread { + messages, + version: SerializedThread::VERSION.to_string(), + ..self.0 + } } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SerializedMessage { pub id: MessageId, pub role: Role, @@ -712,7 +716,7 @@ pub struct SerializedMessage { pub is_hidden: bool, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] #[serde(tag = "type")] pub enum SerializedMessageSegment { #[serde(rename = "text")] @@ -730,14 +734,14 @@ pub enum SerializedMessageSegment { }, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SerializedToolUse { pub id: LanguageModelToolUseId, pub name: SharedString, pub input: serde_json::Value, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SerializedToolResult { pub tool_use_id: LanguageModelToolUseId, pub is_error: bool, @@ -800,7 +804,7 @@ impl LegacySerializedMessage { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SerializedCrease { pub start: usize, pub end: usize, @@ -1057,3 +1061,181 @@ impl ThreadsDatabase { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::thread::{DetailedSummaryState, MessageId}; + use chrono::Utc; + use language_model::{Role, TokenUsage}; + use pretty_assertions::assert_eq; + + #[test] + fn test_legacy_serialized_thread_upgrade() { + let updated_at = Utc::now(); + let legacy_thread = LegacySerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![LegacySerializedMessage { + id: MessageId(1), + role: Role::User, + text: "Hello, world!".to_string(), + tool_uses: vec![], + tool_results: vec![], + }], + initial_project_snapshot: None, + }; + + let upgraded = legacy_thread.upgrade(); + + assert_eq!( + upgraded, + SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Hello, world!".to_string() + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false + }], + version: SerializedThread::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + exceeded_window_error: None, + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None + } + ) + } + + #[test] + fn test_serialized_threadv0_1_0_upgrade() { + let updated_at = Utc::now(); + let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![ + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Use tool_1".to_string(), + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + SerializedMessage { + id: MessageId(2), + role: Role::Assistant, + segments: vec![SerializedMessageSegment::Text { + text: "I want to use a tool".to_string(), + }], + tool_uses: vec![SerializedToolUse { + id: "abc".into(), + name: "tool_1".into(), + input: serde_json::Value::Null, + }], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Here is the tool result".to_string(), + }], + tool_uses: vec![], + tool_results: vec![SerializedToolResult { + tool_use_id: "abc".into(), + is_error: false, + content: LanguageModelToolResultContent::Text("abcdef".into()), + output: Some(serde_json::Value::Null), + }], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + ], + version: SerializedThreadV0_1_0::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + exceeded_window_error: None, + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None, + }); + let upgraded = thread_v0_1_0.upgrade(); + + assert_eq!( + upgraded, + SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![ + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Use tool_1".to_string() + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false + }, + SerializedMessage { + id: MessageId(2), + role: Role::Assistant, + segments: vec![SerializedMessageSegment::Text { + text: "I want to use a tool".to_string(), + }], + tool_uses: vec![SerializedToolUse { + id: "abc".into(), + name: "tool_1".into(), + input: serde_json::Value::Null, + }], + tool_results: vec![SerializedToolResult { + tool_use_id: "abc".into(), + is_error: false, + content: LanguageModelToolResultContent::Text("abcdef".into()), + output: Some(serde_json::Value::Null), + }], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + ], + version: SerializedThread::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + exceeded_window_error: None, + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None + } + ) + } +}