agent: Add tests for thread serialization code (#32383)

This adds some unit tests to ensure that the `update(...)`/migration
path to the latest versions works correctly

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-09 14:20:19 +02:00 committed by GitHub
parent 0cb7dd2972
commit afab4b522e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 197 additions and 13 deletions

1
Cargo.lock generated
View file

@ -99,6 +99,7 @@ dependencies = [
"paths",
"picker",
"postage",
"pretty_assertions",
"project",
"prompt_store",
"proto",

View file

@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true

View file

@ -195,20 +195,20 @@ impl MessageSegment {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>,
pub unsaved_buffer_paths: Vec<String>,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
pub worktree_path: String,
pub git_state: Option<GitState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
pub remote_url: Option<String>,
pub head_sha: Option<String>,
@ -247,7 +247,7 @@ impl LastRestoreCheckpoint {
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum DetailedSummaryState {
#[default]
NotGenerated,
@ -391,7 +391,7 @@ impl ThreadSummary {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ExceededWindowError {
/// Model used when last message exceeded context window
model_id: LanguageModelId,

View file

@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata {
pub updated_at: DateTime<Utc>,
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
pub version: String,
pub summary: SharedString,
@ -629,7 +629,7 @@ pub struct SerializedThread {
pub profile: Option<AgentProfileId>,
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 {
messages.push(message);
}
SerializedThread { messages, ..self.0 }
SerializedThread {
messages,
version: SerializedThread::VERSION.to_string(),
..self.0
}
}
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
@ -712,7 +716,7 @@ pub struct SerializedMessage {
pub is_hidden: bool,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
#[serde(rename = "text")]
@ -730,14 +734,14 @@ pub enum SerializedMessageSegment {
},
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
@ -800,7 +804,7 @@ impl LegacySerializedMessage {
}
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
pub start: usize,
pub end: usize,
@ -1057,3 +1061,181 @@ impl ThreadsDatabase {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::thread::{DetailedSummaryState, MessageId};
use chrono::Utc;
use language_model::{Role, TokenUsage};
use pretty_assertions::assert_eq;
#[test]
fn test_legacy_serialized_thread_upgrade() {
let updated_at = Utc::now();
let legacy_thread = LegacySerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![LegacySerializedMessage {
id: MessageId(1),
role: Role::User,
text: "Hello, world!".to_string(),
tool_uses: vec![],
tool_results: vec![],
}],
initial_project_snapshot: None,
};
let upgraded = legacy_thread.upgrade();
assert_eq!(
upgraded,
SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Hello, world!".to_string()
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
}],
version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None
}
)
}
#[test]
fn test_serialized_threadv0_1_0_upgrade() {
let updated_at = Utc::now();
let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Use tool_1".to_string(),
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(2),
role: Role::Assistant,
segments: vec![SerializedMessageSegment::Text {
text: "I want to use a tool".to_string(),
}],
tool_uses: vec![SerializedToolUse {
id: "abc".into(),
name: "tool_1".into(),
input: serde_json::Value::Null,
}],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Here is the tool result".to_string(),
}],
tool_uses: vec![],
tool_results: vec![SerializedToolResult {
tool_use_id: "abc".into(),
is_error: false,
content: LanguageModelToolResultContent::Text("abcdef".into()),
output: Some(serde_json::Value::Null),
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThreadV0_1_0::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None,
});
let upgraded = thread_v0_1_0.upgrade();
assert_eq!(
upgraded,
SerializedThread {
summary: "Test conversation".into(),
updated_at,
messages: vec![
SerializedMessage {
id: MessageId(1),
role: Role::User,
segments: vec![SerializedMessageSegment::Text {
text: "Use tool_1".to_string()
}],
tool_uses: vec![],
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
},
SerializedMessage {
id: MessageId(2),
role: Role::Assistant,
segments: vec![SerializedMessageSegment::Text {
text: "I want to use a tool".to_string(),
}],
tool_uses: vec![SerializedToolUse {
id: "abc".into(),
name: "tool_1".into(),
input: serde_json::Value::Null,
}],
tool_results: vec![SerializedToolResult {
tool_use_id: "abc".into(),
is_error: false,
content: LanguageModelToolResultContent::Text("abcdef".into()),
output: Some(serde_json::Value::Null),
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: vec![],
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
completion_mode: None,
tool_use_limit_reached: false,
profile: None
}
)
}
}