agent: Add tests for thread serialization code (#32383)
This adds some unit tests to ensure that the `update(...)`/migration path to the latest versions works correctly Release Notes: - N/A
This commit is contained in:
parent
0cb7dd2972
commit
afab4b522e
4 changed files with 197 additions and 13 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -99,6 +99,7 @@ dependencies = [
|
||||||
"paths",
|
"paths",
|
||||||
"picker",
|
"picker",
|
||||||
"postage",
|
"postage",
|
||||||
|
"pretty_assertions",
|
||||||
"project",
|
"project",
|
||||||
"prompt_store",
|
"prompt_store",
|
||||||
"proto",
|
"proto",
|
||||||
|
|
|
@ -109,5 +109,6 @@ gpui = { workspace = true, "features" = ["test-support"] }
|
||||||
indoc.workspace = true
|
indoc.workspace = true
|
||||||
language = { workspace = true, "features" = ["test-support"] }
|
language = { workspace = true, "features" = ["test-support"] }
|
||||||
language_model = { workspace = true, "features" = ["test-support"] }
|
language_model = { workspace = true, "features" = ["test-support"] }
|
||||||
|
pretty_assertions.workspace = true
|
||||||
project = { workspace = true, features = ["test-support"] }
|
project = { workspace = true, features = ["test-support"] }
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
|
|
|
@ -195,20 +195,20 @@ impl MessageSegment {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct ProjectSnapshot {
|
pub struct ProjectSnapshot {
|
||||||
pub worktree_snapshots: Vec<WorktreeSnapshot>,
|
pub worktree_snapshots: Vec<WorktreeSnapshot>,
|
||||||
pub unsaved_buffer_paths: Vec<String>,
|
pub unsaved_buffer_paths: Vec<String>,
|
||||||
pub timestamp: DateTime<Utc>,
|
pub timestamp: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct WorktreeSnapshot {
|
pub struct WorktreeSnapshot {
|
||||||
pub worktree_path: String,
|
pub worktree_path: String,
|
||||||
pub git_state: Option<GitState>,
|
pub git_state: Option<GitState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct GitState {
|
pub struct GitState {
|
||||||
pub remote_url: Option<String>,
|
pub remote_url: Option<String>,
|
||||||
pub head_sha: Option<String>,
|
pub head_sha: Option<String>,
|
||||||
|
@ -247,7 +247,7 @@ impl LastRestoreCheckpoint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||||
pub enum DetailedSummaryState {
|
pub enum DetailedSummaryState {
|
||||||
#[default]
|
#[default]
|
||||||
NotGenerated,
|
NotGenerated,
|
||||||
|
@ -391,7 +391,7 @@ impl ThreadSummary {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct ExceededWindowError {
|
pub struct ExceededWindowError {
|
||||||
/// Model used when last message exceeded context window
|
/// Model used when last message exceeded context window
|
||||||
model_id: LanguageModelId,
|
model_id: LanguageModelId,
|
||||||
|
|
|
@ -603,7 +603,7 @@ pub struct SerializedThreadMetadata {
|
||||||
pub updated_at: DateTime<Utc>,
|
pub updated_at: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||||
pub struct SerializedThread {
|
pub struct SerializedThread {
|
||||||
pub version: String,
|
pub version: String,
|
||||||
pub summary: SharedString,
|
pub summary: SharedString,
|
||||||
|
@ -629,7 +629,7 @@ pub struct SerializedThread {
|
||||||
pub profile: Option<AgentProfileId>,
|
pub profile: Option<AgentProfileId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||||
pub struct SerializedLanguageModel {
|
pub struct SerializedLanguageModel {
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
@ -690,11 +690,15 @@ impl SerializedThreadV0_1_0 {
|
||||||
messages.push(message);
|
messages.push(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
SerializedThread { messages, ..self.0 }
|
SerializedThread {
|
||||||
|
messages,
|
||||||
|
version: SerializedThread::VERSION.to_string(),
|
||||||
|
..self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct SerializedMessage {
|
pub struct SerializedMessage {
|
||||||
pub id: MessageId,
|
pub id: MessageId,
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
|
@ -712,7 +716,7 @@ pub struct SerializedMessage {
|
||||||
pub is_hidden: bool,
|
pub is_hidden: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum SerializedMessageSegment {
|
pub enum SerializedMessageSegment {
|
||||||
#[serde(rename = "text")]
|
#[serde(rename = "text")]
|
||||||
|
@ -730,14 +734,14 @@ pub enum SerializedMessageSegment {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct SerializedToolUse {
|
pub struct SerializedToolUse {
|
||||||
pub id: LanguageModelToolUseId,
|
pub id: LanguageModelToolUseId,
|
||||||
pub name: SharedString,
|
pub name: SharedString,
|
||||||
pub input: serde_json::Value,
|
pub input: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct SerializedToolResult {
|
pub struct SerializedToolResult {
|
||||||
pub tool_use_id: LanguageModelToolUseId,
|
pub tool_use_id: LanguageModelToolUseId,
|
||||||
pub is_error: bool,
|
pub is_error: bool,
|
||||||
|
@ -800,7 +804,7 @@ impl LegacySerializedMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct SerializedCrease {
|
pub struct SerializedCrease {
|
||||||
pub start: usize,
|
pub start: usize,
|
||||||
pub end: usize,
|
pub end: usize,
|
||||||
|
@ -1057,3 +1061,181 @@ impl ThreadsDatabase {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::thread::{DetailedSummaryState, MessageId};
|
||||||
|
use chrono::Utc;
|
||||||
|
use language_model::{Role, TokenUsage};
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_legacy_serialized_thread_upgrade() {
|
||||||
|
let updated_at = Utc::now();
|
||||||
|
let legacy_thread = LegacySerializedThread {
|
||||||
|
summary: "Test conversation".into(),
|
||||||
|
updated_at,
|
||||||
|
messages: vec![LegacySerializedMessage {
|
||||||
|
id: MessageId(1),
|
||||||
|
role: Role::User,
|
||||||
|
text: "Hello, world!".to_string(),
|
||||||
|
tool_uses: vec![],
|
||||||
|
tool_results: vec![],
|
||||||
|
}],
|
||||||
|
initial_project_snapshot: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let upgraded = legacy_thread.upgrade();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
upgraded,
|
||||||
|
SerializedThread {
|
||||||
|
summary: "Test conversation".into(),
|
||||||
|
updated_at,
|
||||||
|
messages: vec![SerializedMessage {
|
||||||
|
id: MessageId(1),
|
||||||
|
role: Role::User,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "Hello, world!".to_string()
|
||||||
|
}],
|
||||||
|
tool_uses: vec![],
|
||||||
|
tool_results: vec![],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false
|
||||||
|
}],
|
||||||
|
version: SerializedThread::VERSION.to_string(),
|
||||||
|
initial_project_snapshot: None,
|
||||||
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
|
request_token_usage: vec![],
|
||||||
|
detailed_summary_state: DetailedSummaryState::default(),
|
||||||
|
exceeded_window_error: None,
|
||||||
|
model: None,
|
||||||
|
completion_mode: None,
|
||||||
|
tool_use_limit_reached: false,
|
||||||
|
profile: None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_serialized_threadv0_1_0_upgrade() {
|
||||||
|
let updated_at = Utc::now();
|
||||||
|
let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
|
||||||
|
summary: "Test conversation".into(),
|
||||||
|
updated_at,
|
||||||
|
messages: vec![
|
||||||
|
SerializedMessage {
|
||||||
|
id: MessageId(1),
|
||||||
|
role: Role::User,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "Use tool_1".to_string(),
|
||||||
|
}],
|
||||||
|
tool_uses: vec![],
|
||||||
|
tool_results: vec![],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false,
|
||||||
|
},
|
||||||
|
SerializedMessage {
|
||||||
|
id: MessageId(2),
|
||||||
|
role: Role::Assistant,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "I want to use a tool".to_string(),
|
||||||
|
}],
|
||||||
|
tool_uses: vec![SerializedToolUse {
|
||||||
|
id: "abc".into(),
|
||||||
|
name: "tool_1".into(),
|
||||||
|
input: serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
tool_results: vec![],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false,
|
||||||
|
},
|
||||||
|
SerializedMessage {
|
||||||
|
id: MessageId(1),
|
||||||
|
role: Role::User,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "Here is the tool result".to_string(),
|
||||||
|
}],
|
||||||
|
tool_uses: vec![],
|
||||||
|
tool_results: vec![SerializedToolResult {
|
||||||
|
tool_use_id: "abc".into(),
|
||||||
|
is_error: false,
|
||||||
|
content: LanguageModelToolResultContent::Text("abcdef".into()),
|
||||||
|
output: Some(serde_json::Value::Null),
|
||||||
|
}],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
version: SerializedThreadV0_1_0::VERSION.to_string(),
|
||||||
|
initial_project_snapshot: None,
|
||||||
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
|
request_token_usage: vec![],
|
||||||
|
detailed_summary_state: DetailedSummaryState::default(),
|
||||||
|
exceeded_window_error: None,
|
||||||
|
model: None,
|
||||||
|
completion_mode: None,
|
||||||
|
tool_use_limit_reached: false,
|
||||||
|
profile: None,
|
||||||
|
});
|
||||||
|
let upgraded = thread_v0_1_0.upgrade();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
upgraded,
|
||||||
|
SerializedThread {
|
||||||
|
summary: "Test conversation".into(),
|
||||||
|
updated_at,
|
||||||
|
messages: vec![
|
||||||
|
SerializedMessage {
|
||||||
|
id: MessageId(1),
|
||||||
|
role: Role::User,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "Use tool_1".to_string()
|
||||||
|
}],
|
||||||
|
tool_uses: vec![],
|
||||||
|
tool_results: vec![],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false
|
||||||
|
},
|
||||||
|
SerializedMessage {
|
||||||
|
id: MessageId(2),
|
||||||
|
role: Role::Assistant,
|
||||||
|
segments: vec![SerializedMessageSegment::Text {
|
||||||
|
text: "I want to use a tool".to_string(),
|
||||||
|
}],
|
||||||
|
tool_uses: vec![SerializedToolUse {
|
||||||
|
id: "abc".into(),
|
||||||
|
name: "tool_1".into(),
|
||||||
|
input: serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
tool_results: vec![SerializedToolResult {
|
||||||
|
tool_use_id: "abc".into(),
|
||||||
|
is_error: false,
|
||||||
|
content: LanguageModelToolResultContent::Text("abcdef".into()),
|
||||||
|
output: Some(serde_json::Value::Null),
|
||||||
|
}],
|
||||||
|
context: "".to_string(),
|
||||||
|
creases: vec![],
|
||||||
|
is_hidden: false,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
version: SerializedThread::VERSION.to_string(),
|
||||||
|
initial_project_snapshot: None,
|
||||||
|
cumulative_token_usage: TokenUsage::default(),
|
||||||
|
request_token_usage: vec![],
|
||||||
|
detailed_summary_state: DetailedSummaryState::default(),
|
||||||
|
exceeded_window_error: None,
|
||||||
|
model: None,
|
||||||
|
completion_mode: None,
|
||||||
|
tool_use_limit_reached: false,
|
||||||
|
profile: None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue