Wire up history completely

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Conrad Irwin 2025-08-18 10:40:15 -06:00
parent d83210d978
commit 4b1a48e4de
7 changed files with 93 additions and 98 deletions

View file

@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata};
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
use collections::IndexMap; use collections::IndexMap;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{Entity, SharedString, Task}; use gpui::{Entity, SharedString, Task};
use project::Project; use project::Project;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -27,6 +26,8 @@ pub trait AgentConnection {
cx: &mut App, cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>; ) -> Task<Result<Entity<AcpThread>>>;
// todo!(expose a history trait, and include list_threads and load_thread)
// todo!(write a test)
fn list_threads( fn list_threads(
&self, &self,
_cx: &mut App, _cx: &mut App,

View file

@ -5,16 +5,15 @@ use crate::{
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates, UserMessageContent, WebSearchTool, templates::Templates,
}; };
use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id}; use crate::{ThreadsDatabase, generate_session_id};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap}; use collections::{HashSet, IndexMap};
use fs::Fs; use fs::Fs;
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; use futures::channel::mpsc;
use futures::future::Shared; use futures::{StreamExt, future};
use futures::{SinkExt, StreamExt, future};
use gpui::{ use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
}; };
@ -30,6 +29,7 @@ use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use util::ResultExt; use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [ const RULES_FILE_NAMES: [&'static str; 9] = [
@ -174,7 +174,7 @@ pub struct NativeAgent {
prompt_store: Option<Entity<PromptStore>>, prompt_store: Option<Entity<PromptStore>>,
thread_database: Arc<ThreadsDatabase>, thread_database: Arc<ThreadsDatabase>,
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>, history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
load_history: Task<Result<()>>, load_history: Task<()>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -212,7 +212,7 @@ impl NativeAgent {
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(()); watch::channel(());
let this = Self { let mut this = Self {
sessions: HashMap::new(), sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)), project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx, project_context_needs_refresh: project_context_needs_refresh_tx,
@ -229,7 +229,7 @@ impl NativeAgent {
prompt_store, prompt_store,
fs, fs,
history: watch::channel(None).0, history: watch::channel(None).0,
load_history: Task::ready(Ok(())), load_history: Task::ready(()),
_subscriptions: subscriptions, _subscriptions: subscriptions,
}; };
this.reload_history(cx); this.reload_history(cx);
@ -249,7 +249,7 @@ impl NativeAgent {
Session { Session {
thread: thread.clone(), thread: thread.clone(),
acp_thread: acp_thread.downgrade(), acp_thread: acp_thread.downgrade(),
save_task: Task::ready(()), save_task: Task::ready(Ok(())),
_subscriptions: vec![ _subscriptions: vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| { cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id()); this.sessions.remove(acp_thread.session_id());
@ -280,24 +280,30 @@ impl NativeAgent {
} }
fn reload_history(&mut self, cx: &mut Context<Self>) { fn reload_history(&mut self, cx: &mut Context<Self>) {
dbg!("");
let thread_database = self.thread_database.clone(); let thread_database = self.thread_database.clone();
self.load_history = cx.spawn(async move |this, cx| { self.load_history = cx.spawn(async move |this, cx| {
let results = cx let results = cx
.background_spawn(async move { .background_spawn(async move {
let results = thread_database.list_threads().await?; let results = thread_database.list_threads().await?;
Ok(results dbg!(&results);
.into_iter() anyhow::Ok(
.map(|thread| AcpThreadMetadata { results
agent: NATIVE_AGENT_SERVER_NAME.clone(), .into_iter()
id: thread.id.into(), .map(|thread| AcpThreadMetadata {
title: thread.title, agent: NATIVE_AGENT_SERVER_NAME.clone(),
updated_at: thread.updated_at, id: thread.id.into(),
}) title: thread.title,
.collect()) updated_at: thread.updated_at,
})
.collect(),
)
}) })
.await?; .await;
this.update(cx, |this, cx| this.history.send(Some(results)))?; if let Some(results) = results.log_err() {
anyhow::Ok(()) this.update(cx, |this, _| this.history.send(Some(results)))
.ok();
}
}); });
} }
@ -509,10 +515,10 @@ impl NativeAgent {
) { ) {
self.models.refresh_list(cx); self.models.refresh_list(cx);
for session in self.sessions.values_mut() { for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| { session.thread.update(cx, |thread, cx| {
let model_id = LanguageModels::model_id(&thread.model()); let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) { if let Some(model) = self.models.model_from_id(&model_id) {
thread.set_model(model.clone()); thread.set_model(model.clone(), cx);
} }
}); });
} }
@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
}; };
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
thread.set_model(model.clone()); thread.set_model(model.clone(), cx);
}); });
update_settings_file::<AgentSettings>( update_settings_file::<AgentSettings>(
@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> { ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let thread_id = ThreadId::from(session_id.clone());
let database = self.0.update(cx, |this, _| this.thread_database.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database let db_thread = database
.load_thread(thread_id.clone()) .load_thread(session_id.clone())
.await? .await?
.context("no such thread found")?; .context("no such thread found")?;
@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::from_db( let mut thread = Thread::from_db(
thread_id, session_id,
db_thread, db_thread,
project.clone(), project.clone(),
agent.project_context.clone(), agent.project_context.clone(),
@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Store the session // Store the session
agent.update(cx, |agent, cx| { agent.update(cx, |agent, cx| {
agent.insert_session(session_id, thread, acp_thread, cx) agent.insert_session(thread.clone(), acp_thread.clone(), cx)
})?; })?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?; let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id); log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| { self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) { if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel()); agent.thread.update(cx, |thread, cx| thread.cancel(cx));
} }
}); });
} }
@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) Task::ready(
self.0
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
)
} }
} }

View file

@ -1,4 +1,4 @@
use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent}; use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use agent::thread_store; use agent::thread_store;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata { pub struct DbThreadMetadata {
pub id: ThreadId, pub id: acp::SessionId,
#[serde(alias = "summary")] #[serde(alias = "summary")]
pub title: SharedString, pub title: SharedString,
pub updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
@ -323,7 +323,7 @@ impl ThreadsDatabase {
for (id, summary, updated_at) in rows { for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata { threads.push(DbThreadMetadata {
id: ThreadId(id), id: acp::SessionId(id),
title: summary.into(), title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
}); });
@ -333,7 +333,7 @@ impl ThreadsDatabase {
}) })
} }
pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> { pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone(); let connection = self.connection.clone();
self.executor.spawn(async move { self.executor.spawn(async move {

View file

@ -1,17 +1,13 @@
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
use agent::{ThreadId, thread_store::ThreadStore};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use assistant_context::SavedContextMetadata; use assistant_context::SavedContextMetadata;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::HashMap;
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; use gpui::{SharedString, Task, prelude::*};
use itertools::Itertools;
use paths::contexts_dir;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smol::stream::StreamExt; use smol::stream::StreamExt;
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; use std::{path::Path, sync::Arc, time::Duration};
use util::ResultExt as _;
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
@ -64,16 +60,16 @@ enum SerializedRecentOpen {
} }
pub struct AgentHistory { pub struct AgentHistory {
entries: HashMap<acp::SessionId, AcpThreadMetadata>, entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
_task: Task<Result<()>>, _task: Task<()>,
} }
pub struct HistoryStore { pub struct HistoryStore {
agents: HashMap<AgentServerName, AgentHistory>, agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
} }
impl HistoryStore { impl HistoryStore {
pub fn new(cx: &mut Context<Self>) -> Self { pub fn new(_cx: &mut Context<Self>) -> Self {
Self { Self {
agents: HashMap::default(), agents: HashMap::default(),
} }
@ -88,33 +84,18 @@ impl HistoryStore {
let Some(mut history) = connection.list_threads(cx) else { let Some(mut history) = connection.list_threads(cx) else {
return; return;
}; };
let task = cx.spawn(async move |this, cx| { let history = AgentHistory {
while let Some(updated_history) = history.next().await { entries: history.clone(),
dbg!(&updated_history); _task: cx.spawn(async move |this, cx| {
this.update(cx, |this, cx| { while history.changed().await.is_ok() {
for entry in updated_history { this.update(cx, |_, cx| cx.notify()).ok();
let agent = this }
.agents }),
.get_mut(&entry.agent) };
.context("agent not found")?; self.agents.insert(agent_name.clone(), history);
agent.entries.insert(entry.id.clone(), entry);
}
cx.notify();
anyhow::Ok(())
})??
}
Ok(())
});
self.agents.insert(
agent_name,
AgentHistory {
entries: Default::default(),
_task: task,
},
);
} }
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> { pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new(); let mut history_entries = Vec::new();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -124,9 +105,8 @@ impl HistoryStore {
history_entries.extend( history_entries.extend(
self.agents self.agents
.values() .values_mut()
.flat_map(|agent| agent.entries.values()) .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
.cloned()
.map(HistoryEntry::Thread), .map(HistoryEntry::Thread),
); );
// todo!() include the text threads in here. // todo!() include the text threads in here.
@ -135,7 +115,7 @@ impl HistoryStore {
history_entries history_entries
} }
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> { pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
self.entries(cx).into_iter().take(limit).collect() self.entries(cx).into_iter().take(limit).collect()
} }
} }

View file

@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even // Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running. // if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel()); thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await; let events = events.collect::<Vec<_>>().await;
let last_event = events.last(); let last_event = events.last();
assert!( assert!(
@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
}); });
thread thread
.update(cx, |thread, _cx| thread.truncate(message_id)) .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {

View file

@ -802,16 +802,18 @@ impl Thread {
&self.model &self.model
} }
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) { pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
self.model = model; self.model = model;
cx.notify()
} }
pub fn completion_mode(&self) -> CompletionMode { pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode self.completion_mode
} }
pub fn set_completion_mode(&mut self, mode: CompletionMode) { pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
self.completion_mode = mode; self.completion_mode = mode;
cx.notify()
} }
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
@ -839,21 +841,22 @@ impl Thread {
self.profile_id = profile_id; self.profile_id = profile_id;
} }
pub fn cancel(&mut self) { pub fn cancel(&mut self, cx: &mut Context<Self>) {
if let Some(running_turn) = self.running_turn.take() { if let Some(running_turn) = self.running_turn.take() {
running_turn.cancel(); running_turn.cancel();
} }
self.flush_pending_message(); self.flush_pending_message(cx);
} }
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
self.cancel(); self.cancel(cx);
let Some(position) = self.messages.iter().position( let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else { ) else {
return Err(anyhow!("Message not found")); return Err(anyhow!("Message not found"));
}; };
self.messages.truncate(position); self.messages.truncate(position);
cx.notify();
Ok(()) Ok(())
} }
@ -900,7 +903,7 @@ impl Thread {
} }
fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> { fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
self.cancel(); self.cancel(cx);
let model = self.model.clone(); let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>(); let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
@ -938,8 +941,8 @@ impl Thread {
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason); event_stream.send_stop(reason);
if reason == StopReason::Refusal { if reason == StopReason::Refusal {
this.update(cx, |this, _cx| { this.update(cx, |this, cx| {
this.flush_pending_message(); this.flush_pending_message(cx);
this.messages.truncate(message_ix); this.messages.truncate(message_ix);
})?; })?;
return Ok(()); return Ok(());
@ -991,7 +994,7 @@ impl Thread {
log::info!("No tool uses found, completing turn"); log::info!("No tool uses found, completing turn");
return Ok(()); return Ok(());
} else { } else {
this.update(cx, |this, _| this.flush_pending_message())?; this.update(cx, |this, cx| this.flush_pending_message(cx))?;
completion_intent = CompletionIntent::ToolResults; completion_intent = CompletionIntent::ToolResults;
} }
} }
@ -1005,8 +1008,8 @@ impl Thread {
log::info!("Turn execution completed successfully"); log::info!("Turn execution completed successfully");
} }
this.update(cx, |this, _| { this.update(cx, |this, cx| {
this.flush_pending_message(); this.flush_pending_message(cx);
this.running_turn.take(); this.running_turn.take();
}) })
.ok(); .ok();
@ -1046,7 +1049,7 @@ impl Thread {
match event { match event {
StartMessage { .. } => { StartMessage { .. } => {
self.flush_pending_message(); self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default()); self.pending_message = Some(AgentMessage::default());
} }
Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@ -1255,7 +1258,7 @@ impl Thread {
self.pending_message.get_or_insert_default() self.pending_message.get_or_insert_default()
} }
fn flush_pending_message(&mut self) { fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
let Some(mut message) = self.pending_message.take() else { let Some(mut message) = self.pending_message.take() else {
return; return;
}; };
@ -1280,6 +1283,7 @@ impl Thread {
} }
self.messages.push(Message::Agent(message)); self.messages.push(Message::Agent(message));
cx.notify()
} }
pub(crate) fn build_completion_request( pub(crate) fn build_completion_request(

View file

@ -2487,12 +2487,15 @@ impl AcpThreadView {
return; return;
}; };
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode(); let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode { thread.set_completion_mode(
CompletionMode::Burn => CompletionMode::Normal, match current_mode {
CompletionMode::Normal => CompletionMode::Burn, CompletionMode::Burn => CompletionMode::Normal,
}); CompletionMode::Normal => CompletionMode::Burn,
},
cx,
);
}); });
} }
@ -3274,8 +3277,8 @@ impl AcpThreadView {
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
.on_click({ .on_click({
cx.listener(move |this, _, _window, cx| { cx.listener(move |this, _, _window, cx| {
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
thread.set_completion_mode(CompletionMode::Burn); thread.set_completion_mode(CompletionMode::Burn, cx);
}); });
this.resume_chat(cx); this.resume_chat(cx);
}) })