Stop generating in the Agent panel when the user edits a previous message (#29915)

Otherwise the panel keeps scrolling as the new token comes in and it is
almost impossible to keep the scroll position in the right place.

Also, if the user is editing, it is likely that the current generated
tokens will need to be regenerated anyway, so we may as well stop the
current progress.

Release Notes:

- Agent Beta: Stop generating tokens if previous messages are edited.
This commit is contained in:
Ben Brandt 2025-05-05 14:06:02 +02:00 committed by GitHub
parent 251f26d48a
commit ce053c9bff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1241,6 +1241,9 @@ impl ActiveThread {
return;
};
// Cancel any ongoing streaming when user starts editing a previous message
self.cancel_last_completion(window, cx);
let editor = crate::message_editor::create_editor(
self.workspace.clone(),
self.context_store.downgrade(),
@ -3464,3 +3467,146 @@ fn open_editor_at_position(
}
})
}
#[cfg(test)]
mod tests {
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use context_server::ContextServerSettings;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{TestAppContext, VisualTestContext};
use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
use project::Project;
use prompt_store::PromptBuilder;
use serde_json::json;
use settings::SettingsStore;
use util::path;
use crate::{ContextLoadResult, thread_store};
use super::*;
#[gpui::test]
async fn test_current_completion_cancelled_when_message_edited(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (cx, active_thread, thread, model) = setup_test_environment(cx, project.clone()).await;
// Insert user message without any context (empty context vector)
let message = thread.update(cx, |thread, cx| {
let message_id = thread.insert_user_message(
"What is the best way to learn Rust?",
ContextLoadResult::default(),
None,
vec![],
cx,
);
thread
.message(message_id)
.expect("message should exist")
.clone()
});
// Stream response to user message
thread.update(cx, |thread, cx| {
let request = thread.to_completion_request(model.clone(), cx);
thread.stream_completion(request, model, cx.active_window(), cx)
});
let generating = thread.update(cx, |thread, _cx| thread.is_generating());
assert!(generating, "There should be one pending completion");
// Edit the previous message
active_thread.update_in(cx, |active_thread, window, cx| {
active_thread.start_editing_message(message.id, &message.segments, window, cx);
});
// Check that the stream was cancelled
let generating = thread.update(cx, |thread, _cx| thread.is_generating());
assert!(!generating, "The completion should have been cancelled");
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});
}
// Helper to create a test project with test files
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
async fn setup_test_environment(
cx: &mut TestAppContext,
project: Entity<Project>,
) -> (
&mut VisualTestContext,
Entity<ActiveThread>,
Entity<Thread>,
Arc<dyn LanguageModel>,
) {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
let model = FakeLanguageModel::default();
let model: Arc<dyn LanguageModel> = Arc::new(model);
let language_registry = LanguageRegistry::new(cx.executor());
let language_registry = Arc::new(language_registry);
let active_thread = cx.update(|window, cx| {
cx.new(|cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
context_store.clone(),
language_registry.clone(),
workspace.downgrade(),
window,
cx,
)
})
});
(cx, active_thread, thread, model)
}
}