diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 398222a831..94b5fe015a 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata}; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; -use futures::channel::mpsc::UnboundedReceiver; use gpui::{Entity, SharedString, Task}; use project::Project; use serde::{Deserialize, Serialize}; @@ -27,6 +26,8 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; + // todo!(expose a history trait, and include list_threads and load_thread) + // todo!(write a test) fn list_threads( &self, _cx: &mut App, diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 403a59e51b..6de5445d80 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,16 +5,15 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id}; +use crate::{ThreadsDatabase, generate_session_id}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; use anyhow::{Context as _, Result, anyhow}; use collections::{HashSet, IndexMap}; use fs::Fs; -use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use futures::future::Shared; -use futures::{SinkExt, StreamExt, future}; +use futures::channel::mpsc; +use futures::{StreamExt, future}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, }; @@ -30,6 +29,7 @@ use std::collections::HashMap; use std::path::Path; use std::rc::Rc; use std::sync::Arc; +use std::time::Duration; use util::ResultExt; const RULES_FILE_NAMES: [&'static str; 9] = [ @@ -174,7 +174,7 @@ pub struct NativeAgent { prompt_store: Option>, thread_database: Arc, history: watch::Sender>>, - load_history: Task>, + load_history: Task<()>, fs: Arc, _subscriptions: Vec, } @@ -212,7 +212,7 @@ impl NativeAgent { let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); - let this = Self { + let mut this = Self { sessions: HashMap::new(), project_context: Rc::new(RefCell::new(project_context)), project_context_needs_refresh: project_context_needs_refresh_tx, @@ -229,7 +229,7 @@ impl NativeAgent { prompt_store, fs, history: watch::channel(None).0, - load_history: Task::ready(Ok(())), + load_history: Task::ready(()), _subscriptions: subscriptions, }; this.reload_history(cx); @@ -249,7 +249,7 @@ impl NativeAgent { Session { thread: thread.clone(), acp_thread: acp_thread.downgrade(), - save_task: Task::ready(()), + save_task: Task::ready(Ok(())), _subscriptions: vec![ cx.observe_release(&acp_thread, |this, acp_thread, _cx| { this.sessions.remove(acp_thread.session_id()); @@ -280,24 +280,30 @@ impl NativeAgent { } fn reload_history(&mut self, cx: &mut Context) { + dbg!(""); let thread_database = self.thread_database.clone(); self.load_history = cx.spawn(async move |this, cx| { let results = cx .background_spawn(async move { let results = thread_database.list_threads().await?; - Ok(results - .into_iter() - .map(|thread| AcpThreadMetadata { - agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id.into(), - title: thread.title, - updated_at: thread.updated_at, - }) - .collect()) + dbg!(&results); + anyhow::Ok( + results + .into_iter() + .map(|thread| AcpThreadMetadata { + agent: NATIVE_AGENT_SERVER_NAME.clone(), + id: thread.id.into(), + title: thread.title, + updated_at: thread.updated_at, + }) + .collect(), + ) }) - .await?; - this.update(cx, |this, cx| this.history.send(Some(results)))?; - anyhow::Ok(()) + .await; + if let Some(results) = results.log_err() { + this.update(cx, |this, _| this.history.send(Some(results))) + .ok(); + } }); } @@ -509,10 +515,10 @@ impl NativeAgent { ) { self.models.refresh_list(cx); for session in self.sessions.values_mut() { - session.thread.update(cx, |thread, _| { + session.thread.update(cx, |thread, cx| { let model_id = LanguageModels::model_id(&thread.model()); if let Some(model) = self.models.model_from_id(&model_id) { - thread.set_model(model.clone()); + thread.set_model(model.clone(), cx); } }); } @@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection { return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); }; - thread.update(cx, |thread, _cx| { - thread.set_model(model.clone()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); }); update_settings_file::( @@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { - let thread_id = ThreadId::from(session_id.clone()); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { - let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(thread_id.clone()) + .load_thread(session_id.clone()) .await? .context("no such thread found")?; @@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let thread = cx.new(|cx| { let mut thread = Thread::from_db( - thread_id, + session_id, db_thread, project.clone(), agent.project_context.clone(), @@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.insert_session(session_id, thread, acp_thread, cx) + agent.insert_session(thread.clone(), acp_thread.clone(), cx) })?; let events = thread.update(cx, |thread, cx| thread.replay(cx))?; @@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::info!("Cancelling on session: {}", session_id); self.0.update(cx, |agent, cx| { if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, _cx| thread.cancel()); + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); } }); } @@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity); impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + Task::ready( + self.0 + .update(cx, |thread, cx| thread.truncate(message_id, cx)), + ) } } diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index 5332da276f..a7240df5c7 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,4 @@ -use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent}; +use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { - pub id: ThreadId, + pub id: acp::SessionId, #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, @@ -323,7 +323,7 @@ impl ThreadsDatabase { for (id, summary, updated_at) in rows { threads.push(DbThreadMetadata { - id: ThreadId(id), + id: acp::SessionId(id), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), }); @@ -333,7 +333,7 @@ impl ThreadsDatabase { }) } - pub fn load_thread(&self, id: ThreadId) -> Task>> { + pub fn load_thread(&self, id: acp::SessionId) -> Task>> { let connection = self.connection.clone(); self.executor.spawn(async move { diff --git a/crates/agent2/src/history_store.rs b/crates/agent2/src/history_store.rs index fb4f34f9c5..f4e53c4c23 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent2/src/history_store.rs @@ -1,17 +1,13 @@ use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName}; -use agent::{ThreadId, thread_store::ThreadStore}; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_context::SavedContextMetadata; use chrono::{DateTime, Utc}; use collections::HashMap; -use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; -use itertools::Itertools; -use paths::contexts_dir; +use gpui::{SharedString, Task, prelude::*}; use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; -use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; -use util::ResultExt as _; +use std::{path::Path, sync::Arc, time::Duration}; const MAX_RECENTLY_OPENED_ENTRIES: usize = 6; const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json"; @@ -64,16 +60,16 @@ enum SerializedRecentOpen { } pub struct AgentHistory { - entries: HashMap, - _task: Task>, + entries: watch::Receiver>>, + _task: Task<()>, } pub struct HistoryStore { - agents: HashMap, + agents: HashMap, // todo!() text threads } impl HistoryStore { - pub fn new(cx: &mut Context) -> Self { + pub fn new(_cx: &mut Context) -> Self { Self { agents: HashMap::default(), } @@ -88,33 +84,18 @@ impl HistoryStore { let Some(mut history) = connection.list_threads(cx) else { return; }; - let task = cx.spawn(async move |this, cx| { - while let Some(updated_history) = history.next().await { - dbg!(&updated_history); - this.update(cx, |this, cx| { - for entry in updated_history { - let agent = this - .agents - .get_mut(&entry.agent) - .context("agent not found")?; - agent.entries.insert(entry.id.clone(), entry); - } - cx.notify(); - anyhow::Ok(()) - })?? - } - Ok(()) - }); - self.agents.insert( - agent_name, - AgentHistory { - entries: Default::default(), - _task: task, - }, - ); + let history = AgentHistory { + entries: history.clone(), + _task: cx.spawn(async move |this, cx| { + while history.changed().await.is_ok() { + this.update(cx, |_, cx| cx.notify()).ok(); + } + }), + }; + self.agents.insert(agent_name.clone(), history); } - pub fn entries(&self, cx: &mut Context) -> Vec { + pub fn entries(&mut self, _cx: &mut Context) -> Vec { let mut history_entries = Vec::new(); #[cfg(debug_assertions)] @@ -124,9 +105,8 @@ impl HistoryStore { history_entries.extend( self.agents - .values() - .flat_map(|agent| agent.entries.values()) - .cloned() + .values_mut() + .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?") .map(HistoryEntry::Thread), ); // todo!() include the text threads in here. @@ -135,7 +115,7 @@ impl HistoryStore { history_entries } - pub fn recent_entries(&self, limit: usize, cx: &mut Context) -> Vec { + pub fn recent_entries(&mut self, limit: usize, cx: &mut Context) -> Vec { self.entries(cx).into_iter().take(limit).collect() } } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 75a21a2baa..2a4d306290 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. - thread.update(cx, |thread, _cx| thread.cancel()); + thread.update(cx, |thread, cx| thread.cancel(cx)); let events = events.collect::>().await; let last_event = events.last(); assert!( @@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, _cx| thread.truncate(message_id)) + .update(cx, |thread, cx| thread.truncate(message_id, cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index ec820c7b5f..7ea5ff7cc6 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -802,16 +802,18 @@ impl Thread { &self.model } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { self.completion_mode } - pub fn set_completion_mode(&mut self, mode: CompletionMode) { + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { self.completion_mode = mode; + cx.notify() } #[cfg(any(test, feature = "test-support"))] @@ -839,21 +841,22 @@ impl Thread { self.profile_id = profile_id; } - pub fn cancel(&mut self) { + pub fn cancel(&mut self, cx: &mut Context) { if let Some(running_turn) = self.running_turn.take() { running_turn.cancel(); } - self.flush_pending_message(); + self.flush_pending_message(cx); } - pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { - self.cancel(); + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), ) else { return Err(anyhow!("Message not found")); }; self.messages.truncate(position); + cx.notify(); Ok(()) } @@ -900,7 +903,7 @@ impl Thread { } fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { - self.cancel(); + self.cancel(cx); let model = self.model.clone(); let (events_tx, events_rx) = mpsc::unbounded::>(); @@ -938,8 +941,8 @@ impl Thread { LanguageModelCompletionEvent::Stop(reason) => { event_stream.send_stop(reason); if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.messages.truncate(message_ix); })?; return Ok(()); @@ -991,7 +994,7 @@ impl Thread { log::info!("No tool uses found, completing turn"); return Ok(()); } else { - this.update(cx, |this, _| this.flush_pending_message())?; + this.update(cx, |this, cx| this.flush_pending_message(cx))?; completion_intent = CompletionIntent::ToolResults; } } @@ -1005,8 +1008,8 @@ impl Thread { log::info!("Turn execution completed successfully"); } - this.update(cx, |this, _| { - this.flush_pending_message(); + this.update(cx, |this, cx| { + this.flush_pending_message(cx); this.running_turn.take(); }) .ok(); @@ -1046,7 +1049,7 @@ impl Thread { match event { StartMessage { .. } => { - self.flush_pending_message(); + self.flush_pending_message(cx); self.pending_message = Some(AgentMessage::default()); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), @@ -1255,7 +1258,7 @@ impl Thread { self.pending_message.get_or_insert_default() } - fn flush_pending_message(&mut self) { + fn flush_pending_message(&mut self, cx: &mut Context) { let Some(mut message) = self.pending_message.take() else { return; }; @@ -1280,6 +1283,7 @@ impl Thread { } self.messages.push(Message::Agent(message)); + cx.notify() } pub(crate) fn build_completion_request( diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 38036ad3c4..40517e49a0 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -2487,12 +2487,15 @@ impl AcpThreadView { return; }; - thread.update(cx, |thread, _cx| { + thread.update(cx, |thread, cx| { let current_mode = thread.completion_mode(); - thread.set_completion_mode(match current_mode { - CompletionMode::Burn => CompletionMode::Normal, - CompletionMode::Normal => CompletionMode::Burn, - }); + thread.set_completion_mode( + match current_mode { + CompletionMode::Burn => CompletionMode::Normal, + CompletionMode::Normal => CompletionMode::Burn, + }, + cx, + ); }); } @@ -3274,8 +3277,8 @@ impl AcpThreadView { .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use.")) .on_click({ cx.listener(move |this, _, _window, cx| { - thread.update(cx, |thread, _cx| { - thread.set_completion_mode(CompletionMode::Burn); + thread.update(cx, |thread, cx| { + thread.set_completion_mode(CompletionMode::Burn, cx); }); this.resume_chat(cx); })