Wire up history completely

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Conrad Irwin 2025-08-18 10:40:15 -06:00
parent d83210d978
commit 4b1a48e4de
7 changed files with 93 additions and 98 deletions

View file

@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata};
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{Entity, SharedString, Task};
use project::Project;
use serde::{Deserialize, Serialize};
@ -27,6 +26,8 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
// todo!(expose a history trait, and include list_threads and load_thread)
// todo!(write a test)
fn list_threads(
&self,
_cx: &mut App,

View file

@ -5,16 +5,15 @@ use crate::{
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
};
use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
use crate::{ThreadsDatabase, generate_session_id};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures::future::Shared;
use futures::{SinkExt, StreamExt, future};
use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
@ -30,6 +29,7 @@ use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [
@ -174,7 +174,7 @@ pub struct NativeAgent {
prompt_store: Option<Entity<PromptStore>>,
thread_database: Arc<ThreadsDatabase>,
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
load_history: Task<Result<()>>,
load_history: Task<()>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@ -212,7 +212,7 @@ impl NativeAgent {
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
let this = Self {
let mut this = Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
@ -229,7 +229,7 @@ impl NativeAgent {
prompt_store,
fs,
history: watch::channel(None).0,
load_history: Task::ready(Ok(())),
load_history: Task::ready(()),
_subscriptions: subscriptions,
};
this.reload_history(cx);
@ -249,7 +249,7 @@ impl NativeAgent {
Session {
thread: thread.clone(),
acp_thread: acp_thread.downgrade(),
save_task: Task::ready(()),
save_task: Task::ready(Ok(())),
_subscriptions: vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
@ -280,24 +280,30 @@ impl NativeAgent {
}
fn reload_history(&mut self, cx: &mut Context<Self>) {
dbg!("");
let thread_database = self.thread_database.clone();
self.load_history = cx.spawn(async move |this, cx| {
let results = cx
.background_spawn(async move {
let results = thread_database.list_threads().await?;
Ok(results
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
.collect())
dbg!(&results);
anyhow::Ok(
results
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
.collect(),
)
})
.await?;
this.update(cx, |this, cx| this.history.send(Some(results)))?;
anyhow::Ok(())
.await;
if let Some(results) = results.log_err() {
this.update(cx, |this, _| this.history.send(Some(results)))
.ok();
}
});
}
@ -509,10 +515,10 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
session.thread.update(cx, |thread, cx| {
let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
thread.set_model(model.clone());
thread.set_model(model.clone(), cx);
}
});
}
@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.set_model(model.clone());
thread.update(cx, |thread, cx| {
thread.set_model(model.clone(), cx);
});
update_settings_file::<AgentSettings>(
@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let thread_id = ThreadId::from(session_id.clone());
let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| {
let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database
.load_thread(thread_id.clone())
.load_thread(session_id.clone())
.await?
.context("no such thread found")?;
@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let thread = cx.new(|cx| {
let mut thread = Thread::from_db(
thread_id,
session_id,
db_thread,
project.clone(),
agent.project_context.clone(),
@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Store the session
agent.update(cx, |agent, cx| {
agent.insert_session(session_id, thread, acp_thread, cx)
agent.insert_session(thread.clone(), acp_thread.clone(), cx)
})?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel());
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
Task::ready(
self.0
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
)
}
}

View file

@ -1,4 +1,4 @@
use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use agent::thread_store;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode};
@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
pub id: ThreadId,
pub id: acp::SessionId,
#[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
@ -323,7 +323,7 @@ impl ThreadsDatabase {
for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata {
id: ThreadId(id),
id: acp::SessionId(id),
title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
});
@ -333,7 +333,7 @@ impl ThreadsDatabase {
})
}
pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {

View file

@ -1,17 +1,13 @@
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
use agent::{ThreadId, thread_store::ThreadStore};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_context::SavedContextMetadata;
use chrono::{DateTime, Utc};
use collections::HashMap;
use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
use itertools::Itertools;
use paths::contexts_dir;
use gpui::{SharedString, Task, prelude::*};
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
use util::ResultExt as _;
use std::{path::Path, sync::Arc, time::Duration};
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
@ -64,16 +60,16 @@ enum SerializedRecentOpen {
}
pub struct AgentHistory {
entries: HashMap<acp::SessionId, AcpThreadMetadata>,
_task: Task<Result<()>>,
entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
_task: Task<()>,
}
pub struct HistoryStore {
agents: HashMap<AgentServerName, AgentHistory>,
agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
}
impl HistoryStore {
pub fn new(cx: &mut Context<Self>) -> Self {
pub fn new(_cx: &mut Context<Self>) -> Self {
Self {
agents: HashMap::default(),
}
@ -88,33 +84,18 @@ impl HistoryStore {
let Some(mut history) = connection.list_threads(cx) else {
return;
};
let task = cx.spawn(async move |this, cx| {
while let Some(updated_history) = history.next().await {
dbg!(&updated_history);
this.update(cx, |this, cx| {
for entry in updated_history {
let agent = this
.agents
.get_mut(&entry.agent)
.context("agent not found")?;
agent.entries.insert(entry.id.clone(), entry);
}
cx.notify();
anyhow::Ok(())
})??
}
Ok(())
});
self.agents.insert(
agent_name,
AgentHistory {
entries: Default::default(),
_task: task,
},
);
let history = AgentHistory {
entries: history.clone(),
_task: cx.spawn(async move |this, cx| {
while history.changed().await.is_ok() {
this.update(cx, |_, cx| cx.notify()).ok();
}
}),
};
self.agents.insert(agent_name.clone(), history);
}
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new();
#[cfg(debug_assertions)]
@ -124,9 +105,8 @@ impl HistoryStore {
history_entries.extend(
self.agents
.values()
.flat_map(|agent| agent.entries.values())
.cloned()
.values_mut()
.flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
.map(HistoryEntry::Thread),
);
// todo!() include the text threads in here.
@ -135,7 +115,7 @@ impl HistoryStore {
history_entries
}
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
self.entries(cx).into_iter().take(limit).collect()
}
}

View file

@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel());
thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
thread
.update(cx, |thread, _cx| thread.truncate(message_id))
.update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {

View file

@ -802,16 +802,18 @@ impl Thread {
&self.model
}
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
self.model = model;
cx.notify()
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
self.completion_mode = mode;
cx.notify()
}
#[cfg(any(test, feature = "test-support"))]
@ -839,21 +841,22 @@ impl Thread {
self.profile_id = profile_id;
}
pub fn cancel(&mut self) {
pub fn cancel(&mut self, cx: &mut Context<Self>) {
if let Some(running_turn) = self.running_turn.take() {
running_turn.cancel();
}
self.flush_pending_message();
self.flush_pending_message(cx);
}
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
self.cancel();
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
self.cancel(cx);
let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else {
return Err(anyhow!("Message not found"));
};
self.messages.truncate(position);
cx.notify();
Ok(())
}
@ -900,7 +903,7 @@ impl Thread {
}
fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
self.cancel();
self.cancel(cx);
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
@ -938,8 +941,8 @@ impl Thread {
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
this.messages.truncate(message_ix);
})?;
return Ok(());
@ -991,7 +994,7 @@ impl Thread {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
completion_intent = CompletionIntent::ToolResults;
}
}
@ -1005,8 +1008,8 @@ impl Thread {
log::info!("Turn execution completed successfully");
}
this.update(cx, |this, _| {
this.flush_pending_message();
this.update(cx, |this, cx| {
this.flush_pending_message(cx);
this.running_turn.take();
})
.ok();
@ -1046,7 +1049,7 @@ impl Thread {
match event {
StartMessage { .. } => {
self.flush_pending_message();
self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default());
}
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@ -1255,7 +1258,7 @@ impl Thread {
self.pending_message.get_or_insert_default()
}
fn flush_pending_message(&mut self) {
fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
let Some(mut message) = self.pending_message.take() else {
return;
};
@ -1280,6 +1283,7 @@ impl Thread {
}
self.messages.push(Message::Agent(message));
cx.notify()
}
pub(crate) fn build_completion_request(

View file

@ -2487,12 +2487,15 @@ impl AcpThreadView {
return;
};
thread.update(cx, |thread, _cx| {
thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
thread.set_completion_mode(
match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
},
cx,
);
});
}
@ -3274,8 +3277,8 @@ impl AcpThreadView {
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
.on_click({
cx.listener(move |this, _, _window, cx| {
thread.update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
thread.update(cx, |thread, cx| {
thread.set_completion_mode(CompletionMode::Burn, cx);
});
this.resume_chat(cx);
})