diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 1b1c014b79..4acd72f275 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver { } } + pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + update, + )))) = event + { + update.fields + } else { + panic!("Expected update fields but got: {:?}", event); + } + } + + pub async fn expect_diff(&mut self) -> Entity { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff( + update, + )))) = event + { + update.diff + } else { + panic!("Expected diff but got: {:?}", event); + } + } + pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 5a80a9428f..3288233dfe 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -272,164 +272,164 @@ impl AgentTool for EditFileTool { .await?; let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?; - let result = async { - event_stream.update_diff(diff.clone()); + event_stream.update_diff(diff.clone()); + let _finalize_diff = util::defer({ + let diff = diff.downgrade(); + let mut cx = cx.clone(); + move || { + diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); + } + }); - let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let old_text = cx - .background_spawn({ - let old_snapshot = old_snapshot.clone(); - async move { Arc::new(old_snapshot.text()) } - }) - .await; + let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let old_text = cx + .background_spawn({ + let old_snapshot = old_snapshot.clone(); + async move { Arc::new(old_snapshot.text()) } + }) + .await; - let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { - edit_agent.edit( - buffer.clone(), - input.display_description.clone(), - &request, - cx, - ) - } else { - edit_agent.overwrite( - buffer.clone(), - input.display_description.clone(), - &request, - cx, - ) - }; + let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { + edit_agent.edit( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + } else { + edit_agent.overwrite( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + }; - let mut hallucinated_old_text = false; - let mut ambiguous_ranges = Vec::new(); - let mut emitted_location = false; - while let Some(event) = events.next().await { - match event { - EditAgentOutputEvent::Edited(range) => { - if !emitted_location { - let line = buffer.update(cx, |buffer, _cx| { - range.start.to_point(&buffer.snapshot()).row - }).ok(); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields(ToolCallUpdateFields { - locations: Some(vec![ToolCallLocation { path: abs_path, line }]), - ..Default::default() - }); - } - emitted_location = true; + let mut hallucinated_old_text = false; + let mut ambiguous_ranges = Vec::new(); + let mut emitted_location = false; + while let Some(event) = events.next().await { + match event { + EditAgentOutputEvent::Edited(range) => { + if !emitted_location { + let line = buffer.update(cx, |buffer, _cx| { + range.start.to_point(&buffer.snapshot()).row + }).ok(); + if let Some(abs_path) = abs_path.clone() { + event_stream.update_fields(ToolCallUpdateFields { + locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + ..Default::default() + }); } - }, - EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, - EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, - EditAgentOutputEvent::ResolvingEditRange(range) => { - diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?; - // if !emitted_location { - // let line = buffer.update(cx, |buffer, _cx| { - // range.start.to_point(&buffer.snapshot()).row - // }).ok(); - // if let Some(abs_path) = abs_path.clone() { - // event_stream.update_fields(ToolCallUpdateFields { - // locations: Some(vec![ToolCallLocation { path: abs_path, line }]), - // ..Default::default() - // }); - // } - // } + emitted_location = true; } + }, + EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, + EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, + EditAgentOutputEvent::ResolvingEditRange(range) => { + diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?; + // if !emitted_location { + // let line = buffer.update(cx, |buffer, _cx| { + // range.start.to_point(&buffer.snapshot()).row + // }).ok(); + // if let Some(abs_path) = abs_path.clone() { + // event_stream.update_fields(ToolCallUpdateFields { + // locations: Some(vec![ToolCallLocation { path: abs_path, line }]), + // ..Default::default() + // }); + // } + // } } } + } - // If format_on_save is enabled, format the buffer - let format_on_save_enabled = buffer - .read_with(cx, |buffer, cx| { - let settings = language_settings::language_settings( - buffer.language().map(|l| l.name()), - buffer.file(), - cx, - ); - settings.format_on_save != FormatOnSave::Off - }) - .unwrap_or(false); + // If format_on_save is enabled, format the buffer + let format_on_save_enabled = buffer + .read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + settings.format_on_save != FormatOnSave::Off + }) + .unwrap_or(false); - let edit_agent_output = output.await?; - - if format_on_save_enabled { - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - })?; - - let format_task = project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([buffer.clone()]), - LspFormatTarget::Buffers, - false, // Don't push to history since the tool did it. - FormatTrigger::Save, - cx, - ) - })?; - format_task.await.log_err(); - } - - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; + let edit_agent_output = output.await?; + if format_on_save_enabled { action_log.update(cx, |log, cx| { log.buffer_edited(buffer.clone(), cx); })?; - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let (new_text, unified_diff) = cx - .background_spawn({ - let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); - async move { - let new_text = new_snapshot.text(); - let diff = language::unified_diff(&old_text, &new_text); - (new_text, diff) - } - }) - .await; + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, // Don't push to history since the tool did it. + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + } - let input_path = input.path.display(); - if unified_diff.is_empty() { - anyhow::ensure!( - !hallucinated_old_text, - formatdoc! {" - Some edits were produced but none of them could be applied. - Read the relevant sections of {input_path} again so that - I can perform the requested edits. - "} - ); - anyhow::ensure!( - ambiguous_ranges.is_empty(), - { - let line_numbers = ambiguous_ranges - .iter() - .map(|range| range.start.to_string()) - .collect::>() - .join(", "); - formatdoc! {" - matches more than one position in the file (lines: {line_numbers}). Read the - relevant sections of {input_path} again and extend so - that I can perform the requested edits. - "} - } - ); - } + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; - Ok(EditFileToolOutput { - input_path: input.path, - new_text, - old_text, - diff: unified_diff, - edit_agent_output, + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + + let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let (new_text, unified_diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + (new_text, diff) + } }) - }.await; + .await; - // Always finalize the diff, regardless of whether the operation succeeded or failed - diff.update(cx, |diff, cx| diff.finalize(cx)).ok(); + let input_path = input.path.display(); + if unified_diff.is_empty() { + anyhow::ensure!( + !hallucinated_old_text, + formatdoc! {" + Some edits were produced but none of them could be applied. + Read the relevant sections of {input_path} again so that + I can perform the requested edits. + "} + ); + anyhow::ensure!( + ambiguous_ranges.is_empty(), + { + let line_numbers = ambiguous_ranges + .iter() + .map(|range| range.start.to_string()) + .collect::>() + .join(", "); + formatdoc! {" + matches more than one position in the file (lines: {line_numbers}). Read the + relevant sections of {input_path} again and extend so + that I can perform the requested edits. + "} + } + ); + } - result + Ok(EditFileToolOutput { + input_path: input.path, + new_text, + old_text, + diff: unified_diff, + edit_agent_output, + }) }) } @@ -1550,6 +1550,76 @@ mod tests { ); } + #[gpui::test] + async fn test_diff_finalization(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({"main.rs": ""})).await; + + let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; + let languages = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + // Ensure the diff is finalized if an error occurs while editing. + { + model.forbid_requests(); + let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone())); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + EditFileToolInput { + display_description: "Edit file".into(), + path: "/main.rs".into(), + mode: EditFileMode::Edit, + }, + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + edit.await.unwrap_err(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + model.allow_requests(); + } + + // Ensure the diff is finalized if the tool call gets dropped. + { + let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone())); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + EditFileToolInput { + display_description: "Edit file".into(), + path: "/main.rs".into(), + mode: EditFileMode::Edit, + }, + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + drop(edit); + cx.run_until_parked(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index ebfd37d16c..b06a475f93 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -4,12 +4,16 @@ use crate::{ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, }; +use anyhow::anyhow; use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream}; use gpui::{AnyView, App, AsyncApp, Entity, Task, Window}; use http_client::Result; use parking_lot::Mutex; use smol::stream::StreamExt; -use std::sync::Arc; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering::SeqCst}, +}; #[derive(Clone)] pub struct FakeLanguageModelProvider { @@ -106,6 +110,7 @@ pub struct FakeLanguageModel { >, )>, >, + forbid_requests: AtomicBool, } impl Default for FakeLanguageModel { @@ -114,11 +119,20 @@ impl Default for FakeLanguageModel { provider_id: LanguageModelProviderId::from("fake".to_string()), provider_name: LanguageModelProviderName::from("Fake".to_string()), current_completion_txs: Mutex::new(Vec::new()), + forbid_requests: AtomicBool::new(false), } } } impl FakeLanguageModel { + pub fn allow_requests(&self) { + self.forbid_requests.store(false, SeqCst); + } + + pub fn forbid_requests(&self) { + self.forbid_requests.store(true, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel { LanguageModelCompletionError, >, > { - let (tx, rx) = mpsc::unbounded(); - self.current_completion_txs.lock().push((request, tx)); - async move { Ok(rx.boxed()) }.boxed() + if self.forbid_requests.load(SeqCst) { + async move { + Err(LanguageModelCompletionError::Other(anyhow!( + "requests are forbidden" + ))) + } + .boxed() + } else { + let (tx, rx) = mpsc::unbounded(); + self.current_completion_txs.lock().push((request, tx)); + async move { Ok(rx.boxed()) }.boxed() + } } fn as_fake(&self) -> &Self {