agent2: Always finalize diffs from the edit tool (#36918)

Previously, we wouldn't finalize the diff if an error occurred during
editing or the tool call was canceled.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Ben Brandt 2025-08-26 02:46:29 -07:00 committed by GitHub
parent c14d84cfdb
commit b249593abe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 152 additions and 6 deletions

View file

@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
} }
} }
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
update,
)))) = event
{
update.fields
} else {
panic!("Expected update fields but got: {:?}", event);
}
}
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
update,
)))) = event
{
update.diff
} else {
panic!("Expected diff but got: {:?}", event);
}
}
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> { pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await; let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(

View file

@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?; let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
event_stream.update_diff(diff.clone()); event_stream.update_diff(diff.clone());
let _finalize_diff = util::defer({
let diff = diff.downgrade();
let mut cx = cx.clone();
move || {
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
}
});
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx let old_text = cx
@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
}) })
.await; .await;
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
let input_path = input.path.display(); let input_path = input.path.display();
if unified_diff.is_empty() { if unified_diff.is_empty() {
anyhow::ensure!( anyhow::ensure!(
@ -1545,6 +1550,100 @@ mod tests {
); );
} }
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/", json!({"main.rs": ""})).await;
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry.clone(),
Templates::new(),
Some(model.clone()),
cx,
)
});
// Ensure the diff is finalized after the edit completes.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
cx.run_until_parked();
model.end_last_completion_stream();
edit.await.unwrap();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
// Ensure the diff is finalized if an error occurs while editing.
{
model.forbid_requests();
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
edit.await.unwrap_err();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
model.allow_requests();
}
// Ensure the diff is finalized if the tool call gets dropped.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
EditFileToolInput {
display_description: "Edit file".into(),
path: path!("/main.rs").into(),
mode: EditFileMode::Edit,
},
stream_tx,
cx,
)
});
stream_rx.expect_update_fields().await;
let diff = stream_rx.expect_diff().await;
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
drop(edit);
cx.run_until_parked();
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
}
}
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| { cx.update(|cx| {
let settings_store = SettingsStore::test(cx); let settings_store = SettingsStore::test(cx);

View file

@ -4,12 +4,16 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelRequest, LanguageModelToolChoice,
}; };
use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result; use http_client::Result;
use parking_lot::Mutex; use parking_lot::Mutex;
use smol::stream::StreamExt; use smol::stream::StreamExt;
use std::sync::Arc; use std::sync::{
Arc,
atomic::{AtomicBool, Ordering::SeqCst},
};
#[derive(Clone)] #[derive(Clone)]
pub struct FakeLanguageModelProvider { pub struct FakeLanguageModelProvider {
@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
>, >,
)>, )>,
>, >,
forbid_requests: AtomicBool,
} }
impl Default for FakeLanguageModel { impl Default for FakeLanguageModel {
@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
provider_id: LanguageModelProviderId::from("fake".to_string()), provider_id: LanguageModelProviderId::from("fake".to_string()),
provider_name: LanguageModelProviderName::from("Fake".to_string()), provider_name: LanguageModelProviderName::from("Fake".to_string()),
current_completion_txs: Mutex::new(Vec::new()), current_completion_txs: Mutex::new(Vec::new()),
forbid_requests: AtomicBool::new(false),
} }
} }
} }
impl FakeLanguageModel { impl FakeLanguageModel {
pub fn allow_requests(&self) {
self.forbid_requests.store(false, SeqCst);
}
pub fn forbid_requests(&self) {
self.forbid_requests.store(true, SeqCst);
}
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> { pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs self.current_completion_txs
.lock() .lock()
@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
LanguageModelCompletionError, LanguageModelCompletionError,
>, >,
> { > {
let (tx, rx) = mpsc::unbounded(); if self.forbid_requests.load(SeqCst) {
self.current_completion_txs.lock().push((request, tx)); async move {
async move { Ok(rx.boxed()) }.boxed() Err(LanguageModelCompletionError::Other(anyhow!(
"requests are forbidden"
)))
}
.boxed()
} else {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move { Ok(rx.boxed()) }.boxed()
}
} }
fn as_fake(&self) -> &Self { fn as_fake(&self) -> &Self {