From 0201d1e0b4867d55209b34c65b80ec4ba4c06a65 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 23 May 2025 14:55:08 +0200 Subject: [PATCH] agent: Unfollow agent on completion cancellation (#31258) Handle unfollowing agent and clearing agent location when completion is canceled. Release Notes: - N/A --- crates/agent/src/active_thread.rs | 179 +++++++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 2 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 8229df3541..2dbd97edf0 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -52,7 +52,7 @@ use ui::{ }; use util::ResultExt as _; use util::markdown::MarkdownCodeBlock; -use workspace::Workspace; +use workspace::{CollaboratorId, Workspace}; use zed_actions::assistant::OpenRulesLibrary; pub struct ActiveThread { @@ -971,7 +971,22 @@ impl ActiveThread { ThreadEvent::ShowError(error) => { self.last_error = Some(error.clone()); } - ThreadEvent::NewRequest | ThreadEvent::CompletionCanceled => { + ThreadEvent::NewRequest => { + cx.notify(); + } + ThreadEvent::CompletionCanceled => { + self.thread.update(cx, |thread, cx| { + thread.project().update(cx, |project, cx| { + project.set_agent_location(None, cx); + }) + }); + self.workspace + .update(cx, |workspace, cx| { + if workspace.is_being_followed(CollaboratorId::Agent) { + workspace.unfollow(CollaboratorId::Agent, window, cx); + } + }) + .ok(); cx.notify(); } ThreadEvent::StreamedCompletion @@ -3593,3 +3608,163 @@ fn open_editor_at_position( } }) } + +#[cfg(test)] +mod tests { + use assistant_tool::{ToolRegistry, ToolWorkingSet}; + use editor::EditorSettings; + use fs::FakeFs; + use gpui::{AppContext, TestAppContext, VisualTestContext}; + use language_model::{LanguageModel, fake_provider::FakeLanguageModel}; + use project::Project; + use prompt_store::PromptBuilder; + use serde_json::json; + use settings::SettingsStore; + use util::path; + use workspace::CollaboratorId; + + use crate::{ContextLoadResult, thread_store}; + + use super::*; + + #[gpui::test] + async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (cx, _active_thread, workspace, thread, model) = + setup_test_environment(cx, project.clone()).await; + + // Insert user message without any context (empty context vector) + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "What is the best way to learn Rust?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + }); + + // Stream response to user message + thread.update(cx, |thread, cx| { + let request = thread.to_completion_request(model.clone(), cx); + thread.stream_completion(request, model, cx.active_window(), cx) + }); + // Follow the agent + cx.update(|window, cx| { + workspace.update(cx, |workspace, cx| { + workspace.follow(CollaboratorId::Agent, window, cx); + }) + }); + assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); + + // Cancel the current completion + thread.update(cx, |thread, cx| { + thread.cancel_last_completion(cx.active_window(), cx) + }); + + cx.executor().run_until_parked(); + + // No longer following the agent + assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + AssistantSettings::register(cx); + prompt_store::init(cx); + thread_store::init(cx); + workspace::init_settings(cx); + language_model::init_settings(cx); + ThemeSettings::register(cx); + EditorSettings::register(cx); + ToolRegistry::default_global(cx); + }); + } + + // Helper to create a test project with test files + async fn create_test_project( + cx: &mut TestAppContext, + files: serde_json::Value, + ) -> Entity { + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/test"), files).await; + Project::test(fs, [path!("/test").as_ref()], cx).await + } + + async fn setup_test_environment( + cx: &mut TestAppContext, + project: Entity, + ) -> ( + &mut VisualTestContext, + Entity, + Entity, + Entity, + Arc, + ) { + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + + let thread_store = cx + .update(|_, cx| { + ThreadStore::load( + project.clone(), + cx.new(|_| ToolWorkingSet::default()), + None, + Arc::new(PromptBuilder::new(None).unwrap()), + cx, + ) + }) + .await + .unwrap(); + + let text_thread_store = cx + .update(|_, cx| { + TextThreadStore::new( + project.clone(), + Arc::new(PromptBuilder::new(None).unwrap()), + Default::default(), + cx, + ) + }) + .await + .unwrap(); + + let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let context_store = + cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); + + let model = FakeLanguageModel::default(); + let model: Arc = Arc::new(model); + + let language_registry = LanguageRegistry::new(cx.executor()); + let language_registry = Arc::new(language_registry); + + let active_thread = cx.update(|window, cx| { + cx.new(|cx| { + ActiveThread::new( + thread.clone(), + thread_store.clone(), + text_thread_store, + context_store.clone(), + language_registry.clone(), + workspace.downgrade(), + window, + cx, + ) + }) + }); + + (cx, active_thread, workspace, thread, model) + } +}