Re-add history entries for native agent threads (#36500)
Closes #ISSUE Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
6b6eb11643
commit
6ba52a3a42
16 changed files with 2007 additions and 119 deletions
|
@ -19,6 +19,7 @@ agent-client-protocol.workspace = true
|
|||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_context.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
|
@ -39,6 +40,7 @@ language_model.workspace = true
|
|||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
open.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
portable-pty.workspace = true
|
||||
project.workspace = true
|
||||
|
@ -49,6 +51,7 @@ serde.workspace = true
|
|||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
sqlez.workspace = true
|
||||
task.workspace = true
|
||||
terminal.workspace = true
|
||||
text.workspace = true
|
||||
|
@ -59,6 +62,7 @@ watch.workspace = true
|
|||
web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
|
||||
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
|
||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
|
||||
UserMessageContent, WebSearchTool, templates::Templates,
|
||||
ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
|
||||
templates::Templates,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use crate::{HistoryStore, ThreadsDatabase};
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
|
@ -51,7 +50,8 @@ struct Session {
|
|||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: WeakEntity<acp_thread::AcpThread>,
|
||||
_subscription: Subscription,
|
||||
pending_save: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
|
@ -155,6 +155,7 @@ impl LanguageModels {
|
|||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
history: Entity<HistoryStore>,
|
||||
/// Shared project context for all threads
|
||||
project_context: Entity<ProjectContext>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
|
@ -173,6 +174,7 @@ pub struct NativeAgent {
|
|||
impl NativeAgent {
|
||||
pub async fn new(
|
||||
project: Entity<Project>,
|
||||
history: Entity<HistoryStore>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
|
@ -200,6 +202,7 @@ impl NativeAgent {
|
|||
watch::channel(());
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
history,
|
||||
project_context: cx.new(|_| project_context),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
|
@ -218,6 +221,55 @@ impl NativeAgent {
|
|||
})
|
||||
}
|
||||
|
||||
fn register_session(
|
||||
&mut self,
|
||||
thread_handle: Entity<Thread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Entity<AcpThread> {
|
||||
let connection = Rc::new(NativeAgentConnection(cx.entity()));
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
thread_handle.update(cx, |thread, cx| {
|
||||
thread.set_summarization_model(summarization_model, cx);
|
||||
thread.add_default_tools(cx)
|
||||
});
|
||||
|
||||
let thread = thread_handle.read(cx);
|
||||
let session_id = thread.id().clone();
|
||||
let title = thread.title();
|
||||
let project = thread.project.clone();
|
||||
let action_log = thread.action_log.clone();
|
||||
let acp_thread = cx.new(|_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
title,
|
||||
connection,
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
)
|
||||
});
|
||||
let subscriptions = vec![
|
||||
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
cx.observe(&thread_handle, move |this, thread, cx| {
|
||||
this.save_thread(thread.clone(), cx)
|
||||
}),
|
||||
];
|
||||
|
||||
self.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread: thread_handle,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscriptions: subscriptions,
|
||||
pending_save: Task::ready(()),
|
||||
},
|
||||
);
|
||||
acp_thread
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
@ -444,6 +496,63 @@ impl NativeAgent {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_thread(
|
||||
&mut self,
|
||||
id: acp::SessionId,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
||||
let db_thread = database
|
||||
.load_thread(id.clone())
|
||||
.await?
|
||||
.with_context(|| format!("no thread found with ID: {id:?}"))?;
|
||||
|
||||
let thread = this.update(cx, |this, cx| {
|
||||
let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
|
||||
cx.new(|cx| {
|
||||
Thread::from_db(
|
||||
id.clone(),
|
||||
db_thread,
|
||||
this.project.clone(),
|
||||
this.project_context.clone(),
|
||||
this.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
this.templates.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let acp_thread =
|
||||
this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
|
||||
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
|
||||
cx.update(|cx| {
|
||||
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
|
||||
})?
|
||||
.await?;
|
||||
Ok(acp_thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
let (id, db_thread) =
|
||||
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
|
||||
let Some(session) = self.sessions.get_mut(&id) else {
|
||||
return;
|
||||
};
|
||||
let history = self.history.clone();
|
||||
session.pending_save = cx.spawn(async move |_, cx| {
|
||||
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
|
||||
return;
|
||||
};
|
||||
let db_thread = db_thread.await;
|
||||
database.save_thread(id, db_thread).await.log_err();
|
||||
history.update(cx, |history, cx| history.reload(cx)).ok();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
|
@ -476,13 +585,21 @@ impl NativeAgentConnection {
|
|||
};
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
let mut response_stream = match f(thread, cx) {
|
||||
let response_stream = match f(thread, cx) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
Self::handle_thread_events(response_stream, acp_thread, cx)
|
||||
}
|
||||
|
||||
fn handle_thread_events(
|
||||
mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
|
||||
acp_thread: WeakEntity<AcpThread>,
|
||||
cx: &App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
while let Some(result) = events.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
@ -686,8 +803,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
log::debug!("Total available models: {}", available_count);
|
||||
|
@ -697,72 +812,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
});
|
||||
let summarization_model = registry.thread_summary_model().map(|c| c.model);
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
summarization_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(NowTool);
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||
thread
|
||||
)
|
||||
});
|
||||
|
||||
Ok(thread)
|
||||
},
|
||||
)??;
|
||||
|
||||
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
session_id.clone(),
|
||||
)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
}),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
agent.update(cx, |agent, cx| agent.register_session(thread, cx))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -887,8 +953,11 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
@ -942,9 +1011,12 @@ mod tests {
|
|||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({ "a": {} })).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
let connection = NativeAgentConnection(
|
||||
NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
@ -995,9 +1067,13 @@ mod tests {
|
|||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
|
||||
// Create the agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
mod agent;
|
||||
mod db;
|
||||
mod history_store;
|
||||
mod native_agent_server;
|
||||
mod templates;
|
||||
mod thread;
|
||||
|
@ -9,6 +11,8 @@ mod tools;
|
|||
mod tests;
|
||||
|
||||
pub use agent::*;
|
||||
pub use db::*;
|
||||
pub use history_store::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use templates::*;
|
||||
pub use thread::*;
|
||||
|
|
470
crates/agent2/src/db.rs
Normal file
470
crates/agent2/src/db.rs
Normal file
|
@ -0,0 +1,470 @@
|
|||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||
use agent::thread_store;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
use anyhow::{Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use futures::{FutureExt, future::Shared};
|
||||
use gpui::{BackgroundExecutor, Global, Task};
|
||||
use indoc::indoc;
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlez::{
|
||||
bindable::{Bind, Column},
|
||||
connection::Connection,
|
||||
statement::Statement,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use ui::{App, SharedString};
|
||||
|
||||
pub type DbMessage = crate::Message;
|
||||
pub type DbSummary = agent::thread::DetailedSummaryState;
|
||||
pub type DbLanguageModel = thread_store::SerializedLanguageModel;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DbThreadMetadata {
|
||||
pub id: acp::SessionId,
|
||||
#[serde(alias = "summary")]
|
||||
pub title: SharedString,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DbThread {
|
||||
pub title: SharedString,
|
||||
pub messages: Vec<DbMessage>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub summary: DbSummary,
|
||||
#[serde(default)]
|
||||
pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
|
||||
#[serde(default)]
|
||||
pub cumulative_token_usage: language_model::TokenUsage,
|
||||
#[serde(default)]
|
||||
pub request_token_usage: Vec<language_model::TokenUsage>,
|
||||
#[serde(default)]
|
||||
pub model: Option<DbLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
impl DbThread {
|
||||
pub const VERSION: &'static str = "0.3.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
|
||||
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||
},
|
||||
_ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
|
||||
}
|
||||
}
|
||||
|
||||
fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
|
||||
let mut messages = Vec::new();
|
||||
for msg in thread.messages {
|
||||
let message = match msg.role {
|
||||
language_model::Role::User => {
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Convert segments to content
|
||||
for segment in msg.segments {
|
||||
match segment {
|
||||
thread_store::SerializedMessageSegment::Text { text } => {
|
||||
content.push(UserMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::Thinking { text, .. } => {
|
||||
// User messages don't have thinking segments, but handle gracefully
|
||||
content.push(UserMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
|
||||
// User messages don't have redacted thinking, skip.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no content was added, add context as text if available
|
||||
if content.is_empty() && !msg.context.is_empty() {
|
||||
content.push(UserMessageContent::Text(msg.context));
|
||||
}
|
||||
|
||||
crate::Message::User(UserMessage {
|
||||
// MessageId from old format can't be meaningfully converted, so generate a new one
|
||||
id: acp_thread::UserMessageId::new(),
|
||||
content,
|
||||
})
|
||||
}
|
||||
language_model::Role::Assistant => {
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Convert segments to content
|
||||
for segment in msg.segments {
|
||||
match segment {
|
||||
thread_store::SerializedMessageSegment::Text { text } => {
|
||||
content.push(AgentMessageContent::Text(text));
|
||||
}
|
||||
thread_store::SerializedMessageSegment::Thinking {
|
||||
text,
|
||||
signature,
|
||||
} => {
|
||||
content.push(AgentMessageContent::Thinking { text, signature });
|
||||
}
|
||||
thread_store::SerializedMessageSegment::RedactedThinking { data } => {
|
||||
content.push(AgentMessageContent::RedactedThinking(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool uses
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
for tool_use in msg.tool_uses {
|
||||
tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
|
||||
content.push(AgentMessageContent::ToolUse(
|
||||
language_model::LanguageModelToolUse {
|
||||
id: tool_use.id,
|
||||
name: tool_use.name.into(),
|
||||
raw_input: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
input: tool_use.input,
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
// Convert tool results
|
||||
let mut tool_results = IndexMap::default();
|
||||
for tool_result in msg.tool_results {
|
||||
let name = tool_names_by_id
|
||||
.remove(&tool_result.tool_use_id)
|
||||
.unwrap_or_else(|| SharedString::from("unknown"));
|
||||
tool_results.insert(
|
||||
tool_result.tool_use_id.clone(),
|
||||
language_model::LanguageModelToolResult {
|
||||
tool_use_id: tool_result.tool_use_id,
|
||||
tool_name: name.into(),
|
||||
is_error: tool_result.is_error,
|
||||
content: tool_result.content,
|
||||
output: tool_result.output,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
crate::Message::Agent(AgentMessage {
|
||||
content,
|
||||
tool_results,
|
||||
})
|
||||
}
|
||||
language_model::Role::System => {
|
||||
// Skip system messages as they're not supported in the new format
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
title: thread.summary,
|
||||
messages,
|
||||
updated_at: thread.updated_at,
|
||||
summary: thread.detailed_summary_state,
|
||||
initial_project_snapshot: thread.initial_project_snapshot,
|
||||
cumulative_token_usage: thread.cumulative_token_usage,
|
||||
request_token_usage: thread.request_token_usage,
|
||||
model: thread.model,
|
||||
completion_mode: thread.completion_mode,
|
||||
profile: thread.profile,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub static ZED_STATELESS: std::sync::LazyLock<bool> =
|
||||
std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DataType {
|
||||
#[serde(rename = "json")]
|
||||
Json,
|
||||
#[serde(rename = "zstd")]
|
||||
Zstd,
|
||||
}
|
||||
|
||||
impl Bind for DataType {
|
||||
fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
|
||||
let value = match self {
|
||||
DataType::Json => "json",
|
||||
DataType::Zstd => "zstd",
|
||||
};
|
||||
value.bind(statement, start_index)
|
||||
}
|
||||
}
|
||||
|
||||
impl Column for DataType {
|
||||
fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
|
||||
let (value, next_index) = String::column(statement, start_index)?;
|
||||
let data_type = match value.as_str() {
|
||||
"json" => DataType::Json,
|
||||
"zstd" => DataType::Zstd,
|
||||
_ => anyhow::bail!("Unknown data type: {}", value),
|
||||
};
|
||||
Ok((data_type, next_index))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ThreadsDatabase {
|
||||
executor: BackgroundExecutor,
|
||||
connection: Arc<Mutex<Connection>>,
|
||||
}
|
||||
|
||||
struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
|
||||
|
||||
impl Global for GlobalThreadsDatabase {}
|
||||
|
||||
impl ThreadsDatabase {
|
||||
pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
|
||||
if cx.has_global::<GlobalThreadsDatabase>() {
|
||||
return cx.global::<GlobalThreadsDatabase>().0.clone();
|
||||
}
|
||||
let executor = cx.background_executor().clone();
|
||||
let task = executor
|
||||
.spawn({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
match ThreadsDatabase::new(executor) {
|
||||
Ok(db) => Ok(Arc::new(db)),
|
||||
Err(err) => Err(Arc::new(err)),
|
||||
}
|
||||
}
|
||||
})
|
||||
.shared();
|
||||
|
||||
cx.set_global(GlobalThreadsDatabase(task.clone()));
|
||||
task
|
||||
}
|
||||
|
||||
pub fn new(executor: BackgroundExecutor) -> Result<Self> {
|
||||
let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
|
||||
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
|
||||
} else {
|
||||
let threads_dir = paths::data_dir().join("threads");
|
||||
std::fs::create_dir_all(&threads_dir)?;
|
||||
let sqlite_path = threads_dir.join("threads.db");
|
||||
Connection::open_file(&sqlite_path.to_string_lossy())
|
||||
};
|
||||
|
||||
connection.exec(indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS threads (
|
||||
id TEXT PRIMARY KEY,
|
||||
summary TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
data_type TEXT NOT NULL,
|
||||
data BLOB NOT NULL
|
||||
)
|
||||
"})?()
|
||||
.map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
|
||||
|
||||
let db = Self {
|
||||
executor: executor.clone(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
};
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
fn save_thread_sync(
|
||||
connection: &Arc<Mutex<Connection>>,
|
||||
id: acp::SessionId,
|
||||
thread: DbThread,
|
||||
) -> Result<()> {
|
||||
const COMPRESSION_LEVEL: i32 = 3;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SerializedThread {
|
||||
#[serde(flatten)]
|
||||
thread: DbThread,
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
let title = thread.title.to_string();
|
||||
let updated_at = thread.updated_at.to_rfc3339();
|
||||
let json_data = serde_json::to_string(&SerializedThread {
|
||||
thread,
|
||||
version: DbThread::VERSION,
|
||||
})?;
|
||||
|
||||
let connection = connection.lock();
|
||||
|
||||
let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
|
||||
let data_type = DataType::Zstd;
|
||||
let data = compressed;
|
||||
|
||||
let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
|
||||
INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
|
||||
"})?;
|
||||
|
||||
insert((id.0.clone(), title, updated_at, data_type, data))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
|
||||
let mut select =
|
||||
connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
|
||||
SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
|
||||
"})?;
|
||||
|
||||
let rows = select(())?;
|
||||
let mut threads = Vec::new();
|
||||
|
||||
for (id, summary, updated_at) in rows {
|
||||
threads.push(DbThreadMetadata {
|
||||
id: acp::SessionId(id),
|
||||
title: summary.into(),
|
||||
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(threads)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
|
||||
SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
|
||||
"})?;
|
||||
|
||||
let rows = select(id.0)?;
|
||||
if let Some((data_type, data)) = rows.into_iter().next() {
|
||||
let json_data = match data_type {
|
||||
DataType::Zstd => {
|
||||
let decompressed = zstd::decode_all(&data[..])?;
|
||||
String::from_utf8(decompressed)?
|
||||
}
|
||||
DataType::Json => String::from_utf8(data)?,
|
||||
};
|
||||
let thread = DbThread::from_json(json_data.as_bytes())?;
|
||||
Ok(Some(thread))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move { Self::save_thread_sync(&connection, id, thread) })
|
||||
}
|
||||
|
||||
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
|
||||
let connection = self.connection.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let connection = connection.lock();
|
||||
|
||||
let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
|
||||
DELETE FROM threads WHERE id = ?
|
||||
"})?;
|
||||
|
||||
delete(id.0)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use agent::MessageSegment;
|
||||
use agent::context::LoadedContext;
|
||||
use client::Client;
|
||||
use fs::FakeFs;
|
||||
use gpui::AppContext;
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use language_model::Role;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let clock = Arc::new(clock::FakeSystemClock::new());
|
||||
let client = Client::new(clock, http_client, cx);
|
||||
agent::init(cx);
|
||||
agent_settings::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
|
||||
// Save a thread using the old agent.
|
||||
let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_message(
|
||||
Role::User,
|
||||
vec![MessageSegment::Text("Hey!".into())],
|
||||
LoadedContext::default(),
|
||||
vec![],
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text("How're you doing?".into())],
|
||||
LoadedContext::default(),
|
||||
vec![],
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
thread_store
|
||||
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Open that same thread using the new agent.
|
||||
let db = cx.update(ThreadsDatabase::connect).await.unwrap();
|
||||
let threads = db.list_threads().await.unwrap();
|
||||
assert_eq!(threads.len(), 1);
|
||||
let thread = db
|
||||
.load_thread(threads[0].id.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
|
||||
assert_eq!(
|
||||
thread.messages[1].to_markdown(),
|
||||
"## Assistant\n\nHow're you doing?\n"
|
||||
);
|
||||
}
|
||||
}
|
314
crates/agent2/src/history_store.rs
Normal file
314
crates/agent2/src/history_store.rs
Normal file
|
@ -0,0 +1,314 @@
|
|||
use crate::{DbThreadMetadata, ThreadsDatabase};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_context::SavedContextMetadata;
|
||||
use chrono::{DateTime, Utc};
|
||||
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
|
||||
use itertools::Itertools;
|
||||
use paths::contexts_dir;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
|
||||
use util::ResultExt as _;
|
||||
|
||||
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
|
||||
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
|
||||
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||
|
||||
const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum HistoryEntry {
|
||||
AcpThread(DbThreadMetadata),
|
||||
TextThread(SavedContextMetadata),
|
||||
}
|
||||
|
||||
impl HistoryEntry {
|
||||
pub fn updated_at(&self) -> DateTime<Utc> {
|
||||
match self {
|
||||
HistoryEntry::AcpThread(thread) => thread.updated_at,
|
||||
HistoryEntry::TextThread(context) => context.mtime.to_utc(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> HistoryEntryId {
|
||||
match self {
|
||||
HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
|
||||
HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn title(&self) -> &SharedString {
|
||||
match self {
|
||||
HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE,
|
||||
HistoryEntry::AcpThread(thread) => &thread.title,
|
||||
HistoryEntry::TextThread(context) => &context.title,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic identifier for a history entry.
|
||||
#[derive(Clone, PartialEq, Eq, Debug)]
|
||||
pub enum HistoryEntryId {
|
||||
AcpThread(acp::SessionId),
|
||||
TextThread(Arc<Path>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
enum SerializedRecentOpen {
|
||||
Thread(String),
|
||||
ContextName(String),
|
||||
/// Old format which stores the full path
|
||||
Context(String),
|
||||
}
|
||||
|
||||
pub struct HistoryStore {
|
||||
threads: Vec<DbThreadMetadata>,
|
||||
context_store: Entity<assistant_context::ContextStore>,
|
||||
recently_opened_entries: VecDeque<HistoryEntryId>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
_save_recently_opened_entries_task: Task<()>,
|
||||
}
|
||||
|
||||
impl HistoryStore {
|
||||
pub fn new(
|
||||
context_store: Entity<assistant_context::ContextStore>,
|
||||
initial_recent_entries: impl IntoIterator<Item = HistoryEntryId>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let subscriptions = vec![cx.observe(&context_store, |_, _, cx| cx.notify())];
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let entries = Self::load_recently_opened_entries(cx).await.log_err()?;
|
||||
this.update(cx, |this, _| {
|
||||
this.recently_opened_entries
|
||||
.extend(
|
||||
entries.into_iter().take(
|
||||
MAX_RECENTLY_OPENED_ENTRIES
|
||||
.saturating_sub(this.recently_opened_entries.len()),
|
||||
),
|
||||
);
|
||||
})
|
||||
.ok()
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
context_store,
|
||||
recently_opened_entries: initial_recent_entries.into_iter().collect(),
|
||||
threads: Vec::default(),
|
||||
_subscriptions: subscriptions,
|
||||
_save_recently_opened_entries_task: Task::ready(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_thread(
|
||||
&mut self,
|
||||
id: acp::SessionId,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let database = database_future.await.map_err(|err| anyhow!(err))?;
|
||||
database.delete_thread(id.clone()).await?;
|
||||
this.update(cx, |this, cx| this.reload(cx))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete_text_thread(
|
||||
&mut self,
|
||||
path: Arc<Path>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.delete_local_context(path, cx)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reload(&self, cx: &mut Context<Self>) {
|
||||
let database_future = ThreadsDatabase::connect(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let threads = database_future
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?
|
||||
.list_threads()
|
||||
.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.threads = threads;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
let mut history_entries = Vec::new();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
|
||||
return history_entries;
|
||||
}
|
||||
|
||||
history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
|
||||
history_entries.extend(
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.unordered_contexts()
|
||||
.cloned()
|
||||
.map(HistoryEntry::TextThread),
|
||||
);
|
||||
|
||||
history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
|
||||
history_entries
|
||||
}
|
||||
|
||||
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
|
||||
self.entries(cx).into_iter().take(limit).collect()
|
||||
}
|
||||
|
||||
pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let thread_entries = self.threads.iter().flat_map(|thread| {
|
||||
self.recently_opened_entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(index, entry)| match entry {
|
||||
HistoryEntryId::AcpThread(id) if &thread.id == id => {
|
||||
Some((index, HistoryEntry::AcpThread(thread.clone())))
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
});
|
||||
|
||||
let context_entries =
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.unordered_contexts()
|
||||
.flat_map(|context| {
|
||||
self.recently_opened_entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(index, entry)| match entry {
|
||||
HistoryEntryId::TextThread(path) if &context.path == path => {
|
||||
Some((index, HistoryEntry::TextThread(context.clone())))
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
});
|
||||
|
||||
thread_entries
|
||||
.chain(context_entries)
|
||||
// optimization to halt iteration early
|
||||
.take(self.recently_opened_entries.len())
|
||||
.sorted_unstable_by_key(|(index, _)| *index)
|
||||
.map(|(_, entry)| entry)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
|
||||
let serialized_entries = self
|
||||
.recently_opened_entries
|
||||
.iter()
|
||||
.filter_map(|entry| match entry {
|
||||
HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
|
||||
SerializedRecentOpen::ContextName(file.to_string_lossy().to_string())
|
||||
}),
|
||||
HistoryEntryId::AcpThread(id) => Some(SerializedRecentOpen::Thread(id.to_string())),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
|
||||
.await;
|
||||
cx.background_spawn(async move {
|
||||
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
|
||||
let content = serde_json::to_string(&serialized_entries)?;
|
||||
std::fs::write(path, content)?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
});
|
||||
}
|
||||
|
||||
fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<Vec<HistoryEntryId>>> {
|
||||
cx.background_spawn(async move {
|
||||
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
|
||||
let contents = match smol::fs::read_to_string(path).await {
|
||||
Ok(it) => it,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e)
|
||||
.context("deserializing persisted agent panel navigation history");
|
||||
}
|
||||
};
|
||||
let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&contents)
|
||||
.context("deserializing persisted agent panel navigation history")?
|
||||
.into_iter()
|
||||
.take(MAX_RECENTLY_OPENED_ENTRIES)
|
||||
.flat_map(|entry| match entry {
|
||||
SerializedRecentOpen::Thread(id) => Some(HistoryEntryId::AcpThread(
|
||||
acp::SessionId(id.as_str().into()),
|
||||
)),
|
||||
SerializedRecentOpen::ContextName(file_name) => Some(
|
||||
HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
|
||||
),
|
||||
SerializedRecentOpen::Context(path) => {
|
||||
Path::new(&path).file_name().map(|file_name| {
|
||||
HistoryEntryId::TextThread(contexts_dir().join(file_name).into())
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Ok(entries)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
|
||||
self.recently_opened_entries
|
||||
.retain(|old_entry| old_entry != &entry);
|
||||
self.recently_opened_entries.push_front(entry);
|
||||
self.recently_opened_entries
|
||||
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
|
||||
pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
|
||||
self.recently_opened_entries.retain(|entry| match entry {
|
||||
HistoryEntryId::AcpThread(thread_id) if thread_id == &id => false,
|
||||
_ => true,
|
||||
});
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
|
||||
pub fn replace_recently_opened_text_thread(
|
||||
&mut self,
|
||||
old_path: &Path,
|
||||
new_path: &Arc<Path>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
for entry in &mut self.recently_opened_entries {
|
||||
match entry {
|
||||
HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
|
||||
*entry = HistoryEntryId::TextThread(new_path.clone());
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
|
||||
pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
|
||||
self.recently_opened_entries
|
||||
.retain(|old_entry| old_entry != entry);
|
||||
self.save_recently_opened_entries(cx);
|
||||
}
|
||||
}
|
|
@ -7,16 +7,17 @@ use gpui::{App, Entity, Task};
|
|||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
|
||||
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer {
|
||||
fs: Arc<dyn Fs>,
|
||||
history: Entity<HistoryStore>,
|
||||
}
|
||||
|
||||
impl NativeAgentServer {
|
||||
pub fn new(fs: Arc<dyn Fs>) -> Self {
|
||||
Self { fs }
|
||||
pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
|
||||
Self { fs, history }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,6 +51,7 @@ impl AgentServer for NativeAgentServer {
|
|||
);
|
||||
let project = project.clone();
|
||||
let fs = self.fs.clone();
|
||||
let history = self.history.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
|
@ -57,7 +59,8 @@ impl AgentServer for NativeAgentServer {
|
|||
let prompt_store = prompt_store.await?;
|
||||
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
|
||||
let agent =
|
||||
NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
|
|
|
@ -1273,10 +1273,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
|||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
|
||||
let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
|
||||
|
||||
// Create agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
history_store,
|
||||
templates.clone(),
|
||||
None,
|
||||
fake_fs.clone(),
|
||||
|
@ -1756,7 +1759,6 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
|||
action_log,
|
||||
templates,
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
|
||||
DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
|
||||
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
|
||||
Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
|
||||
};
|
||||
use acp_thread::{MentionUri, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||
|
@ -17,13 +22,13 @@ use futures::{
|
|||
stream::FuturesUnordered,
|
||||
};
|
||||
use git::repository::DiffType;
|
||||
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
TokenUsage,
|
||||
LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
||||
};
|
||||
use project::{
|
||||
Project,
|
||||
|
@ -516,8 +521,8 @@ pub struct Thread {
|
|||
templates: Arc<Templates>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
pub(crate) project: Entity<Project>,
|
||||
pub(crate) action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
|
@ -528,7 +533,6 @@ impl Thread {
|
|||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
model: Option<Arc<dyn LanguageModel>>,
|
||||
summarization_model: Option<Arc<dyn LanguageModel>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
|
@ -557,7 +561,7 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
model,
|
||||
summarization_model,
|
||||
summarization_model: None,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
|
@ -652,6 +656,88 @@ impl Thread {
|
|||
);
|
||||
}
|
||||
|
||||
pub fn from_db(
|
||||
id: acp::SessionId,
|
||||
db_thread: DbThread,
|
||||
project: Entity<Project>,
|
||||
project_context: Entity<ProjectContext>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = db_thread
|
||||
.profile
|
||||
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
|
||||
let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
db_thread
|
||||
.model
|
||||
.and_then(|model| {
|
||||
let model = SelectedModel {
|
||||
provider: model.provider.clone().into(),
|
||||
model: model.model.clone().into(),
|
||||
};
|
||||
registry.select_model(&model, cx)
|
||||
})
|
||||
.or_else(|| registry.default_model())
|
||||
.map(|model| model.model)
|
||||
});
|
||||
|
||||
Self {
|
||||
id,
|
||||
prompt_id: PromptId::new(),
|
||||
title: if db_thread.title.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(db_thread.title.clone())
|
||||
},
|
||||
summary: db_thread.summary,
|
||||
messages: db_thread.messages,
|
||||
completion_mode: db_thread.completion_mode.unwrap_or_default(),
|
||||
running_turn: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
request_token_usage: db_thread.request_token_usage.clone(),
|
||||
cumulative_token_usage: db_thread.cumulative_token_usage,
|
||||
initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
model,
|
||||
summarization_model: None,
|
||||
project,
|
||||
action_log,
|
||||
updated_at: db_thread.updated_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
|
||||
let initial_project_snapshot = self.initial_project_snapshot.clone();
|
||||
let mut thread = DbThread {
|
||||
title: self.title.clone().unwrap_or_default(),
|
||||
messages: self.messages.clone(),
|
||||
updated_at: self.updated_at,
|
||||
summary: self.summary.clone(),
|
||||
initial_project_snapshot: None,
|
||||
cumulative_token_usage: self.cumulative_token_usage,
|
||||
request_token_usage: self.request_token_usage.clone(),
|
||||
model: self.model.as_ref().map(|model| DbLanguageModel {
|
||||
provider: model.provider_id().to_string(),
|
||||
model: model.name().0.to_string(),
|
||||
}),
|
||||
completion_mode: Some(self.completion_mode),
|
||||
profile: Some(self.profile_id.clone()),
|
||||
};
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let initial_project_snapshot = initial_project_snapshot.await;
|
||||
thread.initial_project_snapshot = initial_project_snapshot;
|
||||
thread
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a snapshot of the current project state including git information and unsaved buffers.
|
||||
fn project_snapshot(
|
||||
project: Entity<Project>,
|
||||
|
@ -816,6 +902,32 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
self.add_tool(CopyPathTool::new(self.project.clone()));
|
||||
self.add_tool(CreateDirectoryTool::new(self.project.clone()));
|
||||
self.add_tool(DeletePathTool::new(
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
));
|
||||
self.add_tool(DiagnosticsTool::new(self.project.clone()));
|
||||
self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
|
||||
self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
|
||||
self.add_tool(FindPathTool::new(self.project.clone()));
|
||||
self.add_tool(GrepTool::new(self.project.clone()));
|
||||
self.add_tool(ListDirectoryTool::new(self.project.clone()));
|
||||
self.add_tool(MovePathTool::new(self.project.clone()));
|
||||
self.add_tool(NowTool);
|
||||
self.add_tool(OpenTool::new(self.project.clone()));
|
||||
self.add_tool(ReadFileTool::new(
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
));
|
||||
self.add_tool(TerminalTool::new(self.project.clone(), cx));
|
||||
self.add_tool(ThinkingTool);
|
||||
self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
|
|
@ -554,7 +554,6 @@ mod tests {
|
|||
action_log,
|
||||
Templates::new(),
|
||||
Some(model),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -756,7 +755,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -899,7 +897,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1029,7 +1026,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1168,7 +1164,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1279,7 +1274,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1362,7 +1356,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1448,7 +1441,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
@ -1531,7 +1523,6 @@ mod tests {
|
|||
action_log.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue