Start on a new db module
This commit is contained in:
parent
846ed6adf9
commit
6b6b7e66e1
7 changed files with 180 additions and 7 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -191,6 +191,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"acp_thread",
|
||||
"action_log",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
|
|
|
@ -4,11 +4,12 @@ use anyhow::Result;
|
|||
use collections::IndexMap;
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct UserMessageId(Arc<str>);
|
||||
|
||||
impl UserMessageId {
|
||||
|
|
|
@ -2,6 +2,7 @@ use agent::ThreadId;
|
|||
use anyhow::{Context as _, Result, bail};
|
||||
use file_icons::FileIcons;
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt,
|
||||
ops::Range,
|
||||
|
@ -11,7 +12,7 @@ use std::{
|
|||
use ui::{App, IconName, SharedString};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MentionUri {
|
||||
File {
|
||||
abs_path: PathBuf,
|
||||
|
|
|
@ -17,6 +17,7 @@ action_log.workspace = true
|
|||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
mod agent;
|
||||
mod db;
|
||||
mod native_agent_server;
|
||||
mod templates;
|
||||
mod thread;
|
||||
|
@ -8,6 +9,7 @@ mod tools;
|
|||
mod tests;
|
||||
|
||||
pub use agent::*;
|
||||
pub use db::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use templates::*;
|
||||
pub use thread::*;
|
||||
|
|
167
crates/agent2/src/db.rs
Normal file
167
crates/agent2/src/db.rs
Normal file
|
@ -0,0 +1,167 @@
|
|||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||
use agent::thread_store;
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::SharedString;
|
||||
|
||||
pub type DbMessage = crate::Message;
|
||||
pub type DbSummary = agent::thread::DetailedSummaryState;
|
||||
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
||||
pub type DbThreadMetadata = thread_store::SerializedThreadMetadata;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DbThread {
|
||||
pub title: SharedString,
|
||||
pub messages: Vec<DbMessage>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub summary: DbSummary,
|
||||
#[serde(default)]
|
||||
pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
|
||||
#[serde(default)]
|
||||
pub cumulative_token_usage: language_model::TokenUsage,
|
||||
#[serde(default)]
|
||||
pub request_token_usage: Vec<language_model::TokenUsage>,
|
||||
#[serde(default)]
|
||||
pub model: Option<DbLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
impl DbThread {
|
||||
pub const VERSION: &'static str = "0.3.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
|
||||
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||
},
|
||||
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||
}
|
||||
}
|
||||
|
||||
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
|
||||
let mut messages = Vec::new();
|
||||
for msg in thread.messages {
|
||||
let message = match msg.role {
|
||||
language_model::Role::User => {
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Convert segments to content
|
||||
for segment in msg.segments {
|
||||
match segment {
|
||||
thread_store::SerializedMessageSegment::Text { text } => {
|
||||
content.push(UserMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::Thinking { text, .. } => {
|
||||
// User messages don't have thinking segments, but handle gracefully
|
||||
content.push(UserMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
|
||||
// User messages don't have redacted thinking, skip.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no content was added, add context as text if available
|
||||
if content.is_empty() && !msg.context.is_empty() {
|
||||
content.push(UserMessageContent::Text(msg.context));
|
||||
}
|
||||
|
||||
crate::Message::User(UserMessage {
|
||||
// MessageId from old format can't be meaningfully converted, so generate a new one
|
||||
id: acp_thread::UserMessageId::new(),
|
||||
content,
|
||||
})
|
||||
}
|
||||
language_model::Role::Assistant => {
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Convert segments to content
|
||||
for segment in msg.segments {
|
||||
match segment {
|
||||
thread_store::SerializedMessageSegment::Text { text } => {
|
||||
content.push(AgentMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::Thinking {
|
||||
text,
|
||||
signature,
|
||||
} => {
|
||||
content.push(AgentMessageContent::Thinking { text, signature });
|
||||
}
|
||||
thread_store::SerializedMessageSegment::RedactedThinking { data } => {
|
||||
content.push(AgentMessageContent::RedactedThinking(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool uses
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
for tool_use in msg.tool_uses {
|
||||
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
|
||||
content.push(AgentMessageContent::ToolUse(
|
||||
language_model::LanguageModelToolUse {
|
||||
id: tool_use.id,
|
||||
name: tool_use.name.into(),
|
||||
raw_input: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
input: tool_use.input,
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
// Convert tool results
|
||||
let mut tool_results = IndexMap::default();
|
||||
for tool_result in msg.tool_results {
|
||||
let name = tool_names_by_id
|
||||
.remove(&tool_result.tool_use_id)
|
||||
.unwrap_or_else(|| SharedString::from("unknown"));
|
||||
tool_results.insert(
|
||||
tool_result.tool_use_id.clone(),
|
||||
language_model::LanguageModelToolResult {
|
||||
tool_use_id: tool_result.tool_use_id,
|
||||
tool_name: name.into(),
|
||||
is_error: tool_result.is_error,
|
||||
content: tool_result.content,
|
||||
output: tool_result.output,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
crate::Message::Agent(AgentMessage {
|
||||
content,
|
||||
tool_results,
|
||||
})
|
||||
}
|
||||
language_model::Role::System => {
|
||||
// Skip system messages as they're not supported in the new format
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
title: thread.summary,
|
||||
messages,
|
||||
updated_at: thread.updated_at,
|
||||
summary: thread.detailed_summary_state,
|
||||
initial_project_snapshot: thread.initial_project_snapshot,
|
||||
cumulative_token_usage: thread.cumulative_token_usage,
|
||||
request_token_usage: thread.request_token_usage,
|
||||
model: thread.model,
|
||||
completion_mode: thread.completion_mode,
|
||||
profile: thread.profile,
|
||||
})
|
||||
}
|
||||
}
|
|
@ -29,7 +29,7 @@ use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
|||
use std::{fmt::Write, ops::Range};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
Agent(AgentMessage),
|
||||
|
@ -53,13 +53,13 @@ impl Message {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct UserMessage {
|
||||
pub id: UserMessageId,
|
||||
pub content: Vec<UserMessageContent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum UserMessageContent {
|
||||
Text(String),
|
||||
Mention { uri: MentionUri, content: String },
|
||||
|
@ -376,13 +376,13 @@ impl AgentMessage {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct AgentMessage {
|
||||
pub content: Vec<AgentMessageContent>,
|
||||
pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum AgentMessageContent {
|
||||
Text(String),
|
||||
Thinking {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue