From 4ee565cd392b0563206eb2d2e61be214fa57ba03 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 20 Aug 2025 14:03:20 +0200 Subject: [PATCH] Fix mentions roundtrip from/to database and other history bugs (#36575) Release Notes: - N/A --- crates/agent2/src/agent.rs | 170 +++++++++++++++++++++++++++++++++++- crates/agent2/src/thread.rs | 58 ++++++------ 2 files changed, 200 insertions(+), 28 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5496ecea7b..1fa307511f 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -577,6 +577,10 @@ impl NativeAgent { } fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + if thread.read(cx).is_empty() { + return; + } + let database_future = ThreadsDatabase::connect(cx); let (id, db_thread) = thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx))); @@ -989,12 +993,19 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume { #[cfg(test)] mod tests { + use crate::HistoryEntryId; + use super::*; - use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo}; + use acp_thread::{ + AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri, + }; use fs::FakeFs; use gpui::TestAppContext; + use indoc::indoc; + use language_model::fake_provider::FakeLanguageModel; use serde_json::json; use settings::SettingsStore; + use util::path; #[gpui::test] async fn test_maintaining_project_context(cx: &mut TestAppContext) { @@ -1179,6 +1190,163 @@ mod tests { ); } + #[gpui::test] + #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows + async fn test_save_load_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let agent = NativeAgent::new( + project.clone(), + history_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + // Ensure empty threads are not saved, even if they get mutated. + let model = Arc::new(FakeLanguageModel::default()); + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model, cx); + thread.set_summarization_model(Some(summary_model), cx); + }); + cx.run_until_parked(); + assert_eq!(history_entries(&history_store, cx), vec![]); + + let model = thread.read_with(cx, |thread, _| thread.model().unwrap().clone()); + let model = model.as_fake(); + let summary_model = thread.read_with(cx, |thread, _| { + thread.summarization_model().unwrap().clone() + }); + let summary_model = summary_model.as_fake(); + let send = acp_thread.update(cx, |thread, cx| { + thread.send( + vec![ + "What does ".into(), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: "b.md".into(), + uri: MentionUri::File { + abs_path: path!("/a/b.md").into(), + } + .to_uri() + .to_string(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + " mean?".into(), + ], + cx, + ) + }); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("Lorem."); + model.end_last_completion_stream(); + cx.run_until_parked(); + summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md"); + summary_model.end_last_completion_stream(); + + send.await.unwrap(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + What does [@b.md](file:///a/b.md) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + + // Drop the ACP thread, which should cause the session to be dropped as well. + cx.update(|_| { + drop(thread); + drop(acp_thread); + }); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.sessions.keys().cloned().collect::>(), []); + }); + + // Ensure the thread can be reloaded from disk. + assert_eq!( + history_entries(&history_store, cx), + vec![( + HistoryEntryId::AcpThread(session_id.clone()), + "Explaining /a/b.md".into() + )] + ); + let acp_thread = agent + .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx)) + .await + .unwrap(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User + + What does [@b.md](file:///a/b.md) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + } + + fn history_entries( + history: &Entity, + cx: &mut TestAppContext, + ) -> Vec<(HistoryEntryId, String)> { + history.read_with(cx, |history, cx| { + history + .entries(cx) + .iter() + .map(|e| (e.id(), e.title().to_string())) + .collect::>() + }) + } + fn init_test(cx: &mut TestAppContext) { env_logger::try_init().ok(); cx.update(|cx| { diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index cd97fa2060..c7b1a08b92 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -720,7 +720,7 @@ impl Thread { pub fn to_db(&self, cx: &App) -> Task { let initial_project_snapshot = self.initial_project_snapshot.clone(); let mut thread = DbThread { - title: self.title.clone().unwrap_or_default(), + title: self.title(), messages: self.messages.clone(), updated_at: self.updated_at, detailed_summary: self.summary.clone(), @@ -870,6 +870,10 @@ impl Thread { &self.action_log } + pub fn is_empty(&self) -> bool { + self.messages.is_empty() && self.title.is_none() + } + pub fn model(&self) -> Option<&Arc> { self.model.as_ref() } @@ -884,6 +888,10 @@ impl Thread { cx.notify() } + pub fn summarization_model(&self) -> Option<&Arc> { + self.summarization_model.as_ref() + } + pub fn set_summarization_model( &mut self, model: Option>, @@ -1068,6 +1076,7 @@ impl Thread { event_stream: event_stream.clone(), _task: cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); + let mut update_title = None; let turn_result: Result = async { let mut completion_intent = CompletionIntent::UserPrompt; loop { @@ -1122,10 +1131,15 @@ impl Thread { this.pending_message() .tool_results .insert(tool_result.tool_use_id.clone(), tool_result); - }) - .ok(); + })?; } + this.update(cx, |this, cx| { + if this.title.is_none() && update_title.is_none() { + update_title = Some(this.update_title(&event_stream, cx)); + } + })?; + if tool_use_limit_reached { log::info!("Tool use limit reached, completing turn"); this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; @@ -1146,10 +1160,6 @@ impl Thread { Ok(reason) => { log::info!("Turn execution completed: {:?}", reason); - let update_title = this - .update(cx, |this, cx| this.update_title(&event_stream, cx)) - .ok() - .flatten(); if let Some(update_title) = update_title { update_title.await.context("update title failed").log_err(); } @@ -1593,17 +1603,14 @@ impl Thread { &mut self, event_stream: &ThreadEventStream, cx: &mut Context, - ) -> Option>> { - if self.title.is_some() { - log::debug!("Skipping title generation because we already have one."); - return None; - } - + ) -> Task> { log::info!( "Generating title with model: {:?}", self.summarization_model.as_ref().map(|model| model.name()) ); - let model = self.summarization_model.clone()?; + let Some(model) = self.summarization_model.clone() else { + return Task::ready(Ok(())); + }; let event_stream = event_stream.clone(); let mut request = LanguageModelRequest { intent: Some(CompletionIntent::ThreadSummarization), @@ -1620,7 +1627,7 @@ impl Thread { content: vec![SUMMARIZE_THREAD_PROMPT.into()], cache: false, }); - Some(cx.spawn(async move |this, cx| { + cx.spawn(async move |this, cx| { let mut title = String::new(); let mut messages = model.stream_completion(request, cx).await?; while let Some(event) = messages.next().await { @@ -1655,7 +1662,7 @@ impl Thread { this.title = Some(title); cx.notify(); }) - })) + }) } fn last_user_message(&self) -> Option<&UserMessage> { @@ -2457,18 +2464,15 @@ impl From for acp::ContentBlock { uri: None, }), UserMessageContent::Mention { uri, content } => { - acp::ContentBlock::ResourceLink(acp::ResourceLink { - uri: uri.to_uri().to_string(), - name: uri.name(), + acp::ContentBlock::Resource(acp::EmbeddedResource { + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + mime_type: None, + text: content, + uri: uri.to_uri().to_string(), + }, + ), annotations: None, - description: if content.is_empty() { - None - } else { - Some(content) - }, - mime_type: None, - size: None, - title: None, }) } }