agent: Add tests for thread serialization code (#32383)

This adds some unit tests to ensure that the `update(...)`/migration
path to the latest versions works correctly

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-09 14:20:19 +02:00 committed by GitHub
parent 0cb7dd2972
commit afab4b522e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 197 additions and 13 deletions

1
Cargo.lock generated
View file

@ -99,6 +99,7 @@ dependencies = [
"paths", "paths",
"picker", "picker",
"postage", "postage",
"pretty_assertions",
"project", "project",
"prompt_store", "prompt_store",
"proto", "proto",

View file

@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] } language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] }
rand.workspace = true rand.workspace = true

View file

@ -195,20 +195,20 @@ impl MessageSegment {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot { pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>, pub worktree_snapshots: Vec<WorktreeSnapshot>,
pub unsaved_buffer_paths: Vec<String>, pub unsaved_buffer_paths: Vec<String>,
pub timestamp: DateTime<Utc>, pub timestamp: DateTime<Utc>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot { pub struct WorktreeSnapshot {
pub worktree_path: String, pub worktree_path: String,
pub git_state: Option<GitState>, pub git_state: Option<GitState>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState { pub struct GitState {
pub remote_url: Option<String>, pub remote_url: Option<String>,
pub head_sha: Option<String>, pub head_sha: Option<String>,
@ -247,7 +247,7 @@ impl LastRestoreCheckpoint {
} }
} }
#[derive(Clone, Debug, Default, Serialize, Deserialize)] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState { pub enum DetailedSummaryState {
#[default] #[default]
NotGenerated, NotGenerated,
@ -391,7 +391,7 @@ impl ThreadSummary {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError { pub struct ExceededWindowError {
/// Model used when last message exceeded context window /// Model used when last message exceeded context window
model_id: LanguageModelId, model_id: LanguageModelId,

View file

@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata {
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread { pub struct SerializedThread {
pub version: String, pub version: String,
pub summary: SharedString, pub summary: SharedString,
@ -629,7 +629,7 @@ pub struct SerializedThread {
pub profile: Option<AgentProfileId>, pub profile: Option<AgentProfileId>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel { pub struct SerializedLanguageModel {
pub provider: String, pub provider: String,
pub model: String, pub model: String,
@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 {
messages.push(message); messages.push(message);
} }
SerializedThread { messages, ..self.0 } SerializedThread {
messages,
version: SerializedThread::VERSION.to_string(),
..self.0
}
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage { pub struct SerializedMessage {
pub id: MessageId, pub id: MessageId,
pub role: Role, pub role: Role,
@ -712,7 +716,7 @@ pub struct SerializedMessage {
pub is_hidden: bool, pub is_hidden: bool,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum SerializedMessageSegment { pub enum SerializedMessageSegment {
#[serde(rename = "text")] #[serde(rename = "text")]
@ -730,14 +734,14 @@ pub enum SerializedMessageSegment {
}, },
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse { pub struct SerializedToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
pub name: SharedString, pub name: SharedString,
pub input: serde_json::Value, pub input: serde_json::Value,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult { pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId, pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool, pub is_error: bool,
@ -800,7 +804,7 @@ impl LegacySerializedMessage {
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease { pub struct SerializedCrease {
pub start: usize, pub start: usize,
pub end: usize, pub end: usize,
@ -1057,3 +1061,181 @@ impl ThreadsDatabase {
}) })
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::thread::{DetailedSummaryState, MessageId};
use chrono::Utc;
use language_model::{Role, TokenUsage};
use pretty_assertions::assert_eq;
#[test]
fn test_legacy_serialized_thread_upgrade() {
let updated_at = Utc::now();
let legacy_thread = LegacySerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![LegacySerializedMessage {
id: MessageId(1),
role: Role::User,
text: "Hello, world!".to_string(),
tool_uses: vec![],
tool_results: vec![],
}],
initial_project_snapshot: None,
};
let upgraded = legacy_thread.upgrade();
assert_eq!(
upgraded,
SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Hello, world!".to_string()
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
}],
version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None
}
)
}
#[test]
fn test_serialized_threadv0_1_0_upgrade() {
let updated_at = Utc::now();
let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Use tool_1".to_string(),
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(2),
role: Role::Assistant,
segments: vec![SerializedMessageSegment::Text {
text: "I want to use a tool".to_string(),
}],
tool_uses: vec![SerializedToolUse {
id: "abc".into(),
name: "tool_1".into(),
input: serde_json::Value::Null,
}],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Here is the tool result".to_string(),
}],
tool_uses: vec![],
tool_results: vec![SerializedToolResult {
tool_use_id: "abc".into(),
is_error: false,
content: LanguageModelToolResultContent::Text("abcdef".into()),
output: Some(serde_json::Value::Null),
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThreadV0_1_0::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None,
});
let upgraded = thread_v0_1_0.upgrade();
assert_eq!(
upgraded,
SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Use tool_1".to_string()
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
},
SerializedMessage {
id: MessageId(2),
role: Role::Assistant,
segments: vec![SerializedMessageSegment::Text {
text: "I want to use a tool".to_string(),
}],
tool_uses: vec![SerializedToolUse {
id: "abc".into(),
name: "tool_1".into(),
input: serde_json::Value::Null,
}],
tool_results: vec![SerializedToolResult {
tool_use_id: "abc".into(),
is_error: false,
content: LanguageModelToolResultContent::Text("abcdef".into()),
output: Some(serde_json::Value::Null),
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None
}
)
}
}