agent: Unfollow agent on completion cancellation (#31258)

Handle unfollowing agent and clearing agent location when completion
is canceled.

Release Notes:

- N/A
This commit is contained in:
Ben Brandt 2025-05-23 14:55:08 +02:00 committed by GitHub
parent f48b6b583e
commit 0201d1e0b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -52,7 +52,7 @@ use ui::{
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use workspace::Workspace;
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread {
@ -971,7 +971,22 @@ impl ActiveThread {
ThreadEvent::ShowError(error) => {
self.last_error = Some(error.clone());
}
ThreadEvent::NewRequest | ThreadEvent::CompletionCanceled => {
ThreadEvent::NewRequest => {
cx.notify();
}
ThreadEvent::CompletionCanceled => {
self.thread.update(cx, |thread, cx| {
thread.project().update(cx, |project, cx| {
project.set_agent_location(None, cx);
})
});
self.workspace
.update(cx, |workspace, cx| {
if workspace.is_being_followed(CollaboratorId::Agent) {
workspace.unfollow(CollaboratorId::Agent, window, cx);
}
})
.ok();
cx.notify();
}
ThreadEvent::StreamedCompletion
@ -3593,3 +3608,163 @@ fn open_editor_at_position(
}
})
}
#[cfg(test)]
mod tests {
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, VisualTestContext};
use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
use project::Project;
use prompt_store::PromptBuilder;
use serde_json::json;
use settings::SettingsStore;
use util::path;
use workspace::CollaboratorId;
use crate::{ContextLoadResult, thread_store};
use super::*;
#[gpui::test]
async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (cx, _active_thread, workspace, thread, model) =
setup_test_environment(cx, project.clone()).await;
// Insert user message without any context (empty context vector)
thread.update(cx, |thread, cx| {
thread.insert_user_message(
"What is the best way to learn Rust?",
ContextLoadResult::default(),
None,
vec![],
cx,
);
});
// Stream response to user message
thread.update(cx, |thread, cx| {
let request = thread.to_completion_request(model.clone(), cx);
thread.stream_completion(request, model, cx.active_window(), cx)
});
// Follow the agent
cx.update(|window, cx| {
workspace.update(cx, |workspace, cx| {
workspace.follow(CollaboratorId::Agent, window, cx);
})
});
assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
// Cancel the current completion
thread.update(cx, |thread, cx| {
thread.cancel_last_completion(cx.active_window(), cx)
});
cx.executor().run_until_parked();
// No longer following the agent
assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});
}
// Helper to create a test project with test files
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
async fn setup_test_environment(
cx: &mut TestAppContext,
project: Entity<Project>,
) -> (
&mut VisualTestContext,
Entity<ActiveThread>,
Entity<Workspace>,
Entity<Thread>,
Arc<dyn LanguageModel>,
) {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await
.unwrap();
let text_thread_store = cx
.update(|_, cx| {
TextThreadStore::new(
project.clone(),
Arc::new(PromptBuilder::new(None).unwrap()),
Default::default(),
cx,
)
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store =
cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
let model = FakeLanguageModel::default();
let model: Arc<dyn LanguageModel> = Arc::new(model);
let language_registry = LanguageRegistry::new(cx.executor());
let language_registry = Arc::new(language_registry);
let active_thread = cx.update(|window, cx| {
cx.new(|cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
text_thread_store,
context_store.clone(),
language_registry.clone(),
workspace.downgrade(),
window,
cx,
)
})
});
(cx, active_thread, workspace, thread, model)
}
}