Finalize diff on task drop
This commit is contained in:
parent
afd2c3aa76
commit
0fb6c1ee09
3 changed files with 259 additions and 142 deletions
|
@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||
update,
|
||||
)))) = event
|
||||
{
|
||||
update.fields
|
||||
} else {
|
||||
panic!("Expected update fields but got: {:?}", event);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
|
||||
update,
|
||||
)))) = event
|
||||
{
|
||||
update.diff
|
||||
} else {
|
||||
panic!("Expected diff but got: {:?}", event);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
|
||||
|
|
|
@ -272,164 +272,164 @@ impl AgentTool for EditFileTool {
|
|||
.await?;
|
||||
|
||||
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
|
||||
let result = async {
|
||||
event_stream.update_diff(diff.clone());
|
||||
event_stream.update_diff(diff.clone());
|
||||
let _finalize_diff = util::defer({
|
||||
let diff = diff.downgrade();
|
||||
let mut cx = cx.clone();
|
||||
move || {
|
||||
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
|
||||
}
|
||||
});
|
||||
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let old_text = cx
|
||||
.background_spawn({
|
||||
let old_snapshot = old_snapshot.clone();
|
||||
async move { Arc::new(old_snapshot.text()) }
|
||||
})
|
||||
.await;
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let old_text = cx
|
||||
.background_spawn({
|
||||
let old_snapshot = old_snapshot.clone();
|
||||
async move { Arc::new(old_snapshot.text()) }
|
||||
})
|
||||
.await;
|
||||
|
||||
|
||||
let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
|
||||
edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
|
||||
edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
|
||||
let mut hallucinated_old_text = false;
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
let mut emitted_location = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited(range) => {
|
||||
if !emitted_location {
|
||||
let line = buffer.update(cx, |buffer, _cx| {
|
||||
range.start.to_point(&buffer.snapshot()).row
|
||||
}).ok();
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
emitted_location = true;
|
||||
let mut hallucinated_old_text = false;
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
let mut emitted_location = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited(range) => {
|
||||
if !emitted_location {
|
||||
let line = buffer.update(cx, |buffer, _cx| {
|
||||
range.start.to_point(&buffer.snapshot()).row
|
||||
}).ok();
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
},
|
||||
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
|
||||
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
|
||||
EditAgentOutputEvent::ResolvingEditRange(range) => {
|
||||
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
|
||||
// if !emitted_location {
|
||||
// let line = buffer.update(cx, |buffer, _cx| {
|
||||
// range.start.to_point(&buffer.snapshot()).row
|
||||
// }).ok();
|
||||
// if let Some(abs_path) = abs_path.clone() {
|
||||
// event_stream.update_fields(ToolCallUpdateFields {
|
||||
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
// ..Default::default()
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
emitted_location = true;
|
||||
}
|
||||
},
|
||||
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
|
||||
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
|
||||
EditAgentOutputEvent::ResolvingEditRange(range) => {
|
||||
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
|
||||
// if !emitted_location {
|
||||
// let line = buffer.update(cx, |buffer, _cx| {
|
||||
// range.start.to_point(&buffer.snapshot()).row
|
||||
// }).ok();
|
||||
// if let Some(abs_path) = abs_path.clone() {
|
||||
// event_stream.update_fields(ToolCallUpdateFields {
|
||||
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
// ..Default::default()
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If format_on_save is enabled, format the buffer
|
||||
let format_on_save_enabled = buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let settings = language_settings::language_settings(
|
||||
buffer.language().map(|l| l.name()),
|
||||
buffer.file(),
|
||||
cx,
|
||||
);
|
||||
settings.format_on_save != FormatOnSave::Off
|
||||
})
|
||||
.unwrap_or(false);
|
||||
// If format_on_save is enabled, format the buffer
|
||||
let format_on_save_enabled = buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
let settings = language_settings::language_settings(
|
||||
buffer.language().map(|l| l.name()),
|
||||
buffer.file(),
|
||||
cx,
|
||||
);
|
||||
settings.format_on_save != FormatOnSave::Off
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let edit_agent_output = output.await?;
|
||||
|
||||
if format_on_save_enabled {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let format_task = project.update(cx, |project, cx| {
|
||||
project.format(
|
||||
HashSet::from_iter([buffer.clone()]),
|
||||
LspFormatTarget::Buffers,
|
||||
false, // Don't push to history since the tool did it.
|
||||
FormatTrigger::Save,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
format_task.await.log_err();
|
||||
}
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
let edit_agent_output = output.await?;
|
||||
|
||||
if format_on_save_enabled {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let (new_text, unified_diff) = cx
|
||||
.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
let old_text = old_text.clone();
|
||||
async move {
|
||||
let new_text = new_snapshot.text();
|
||||
let diff = language::unified_diff(&old_text, &new_text);
|
||||
(new_text, diff)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
let format_task = project.update(cx, |project, cx| {
|
||||
project.format(
|
||||
HashSet::from_iter([buffer.clone()]),
|
||||
LspFormatTarget::Buffers,
|
||||
false, // Don't push to history since the tool did it.
|
||||
FormatTrigger::Save,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
format_task.await.log_err();
|
||||
}
|
||||
|
||||
let input_path = input.path.display();
|
||||
if unified_diff.is_empty() {
|
||||
anyhow::ensure!(
|
||||
!hallucinated_old_text,
|
||||
formatdoc! {"
|
||||
Some edits were produced but none of them could be applied.
|
||||
Read the relevant sections of {input_path} again so that
|
||||
I can perform the requested edits.
|
||||
"}
|
||||
);
|
||||
anyhow::ensure!(
|
||||
ambiguous_ranges.is_empty(),
|
||||
{
|
||||
let line_numbers = ambiguous_ranges
|
||||
.iter()
|
||||
.map(|range| range.start.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
formatdoc! {"
|
||||
<old_text> matches more than one position in the file (lines: {line_numbers}). Read the
|
||||
relevant sections of {input_path} again and extend <old_text> so
|
||||
that I can perform the requested edits.
|
||||
"}
|
||||
}
|
||||
);
|
||||
}
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
Ok(EditFileToolOutput {
|
||||
input_path: input.path,
|
||||
new_text,
|
||||
old_text,
|
||||
diff: unified_diff,
|
||||
edit_agent_output,
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let (new_text, unified_diff) = cx
|
||||
.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
let old_text = old_text.clone();
|
||||
async move {
|
||||
let new_text = new_snapshot.text();
|
||||
let diff = language::unified_diff(&old_text, &new_text);
|
||||
(new_text, diff)
|
||||
}
|
||||
})
|
||||
}.await;
|
||||
.await;
|
||||
|
||||
// Always finalize the diff, regardless of whether the operation succeeded or failed
|
||||
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
|
||||
let input_path = input.path.display();
|
||||
if unified_diff.is_empty() {
|
||||
anyhow::ensure!(
|
||||
!hallucinated_old_text,
|
||||
formatdoc! {"
|
||||
Some edits were produced but none of them could be applied.
|
||||
Read the relevant sections of {input_path} again so that
|
||||
I can perform the requested edits.
|
||||
"}
|
||||
);
|
||||
anyhow::ensure!(
|
||||
ambiguous_ranges.is_empty(),
|
||||
{
|
||||
let line_numbers = ambiguous_ranges
|
||||
.iter()
|
||||
.map(|range| range.start.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
formatdoc! {"
|
||||
<old_text> matches more than one position in the file (lines: {line_numbers}). Read the
|
||||
relevant sections of {input_path} again and extend <old_text> so
|
||||
that I can perform the requested edits.
|
||||
"}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
result
|
||||
Ok(EditFileToolOutput {
|
||||
input_path: input.path,
|
||||
new_text,
|
||||
old_text,
|
||||
diff: unified_diff,
|
||||
edit_agent_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1550,6 +1550,76 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_diff_finalization(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({"main.rs": ""})).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
|
||||
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
cx.new(|_cx| ProjectContext::default()),
|
||||
context_server_registry.clone(),
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Ensure the diff is finalized if an error occurs while editing.
|
||||
{
|
||||
model.forbid_requests();
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
let edit = cx.update(|cx| {
|
||||
tool.run(
|
||||
EditFileToolInput {
|
||||
display_description: "Edit file".into(),
|
||||
path: "/main.rs".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
stream_tx,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
stream_rx.expect_update_fields().await;
|
||||
let diff = stream_rx.expect_diff().await;
|
||||
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
|
||||
edit.await.unwrap_err();
|
||||
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
|
||||
model.allow_requests();
|
||||
}
|
||||
|
||||
// Ensure the diff is finalized if the tool call gets dropped.
|
||||
{
|
||||
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
let edit = cx.update(|cx| {
|
||||
tool.run(
|
||||
EditFileToolInput {
|
||||
display_description: "Edit file".into(),
|
||||
path: "/main.rs".into(),
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
stream_tx,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
stream_rx.expect_update_fields().await;
|
||||
let diff = stream_rx.expect_diff().await;
|
||||
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
|
||||
drop(edit);
|
||||
cx.run_until_parked();
|
||||
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
|
||||
}
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
|
|
@ -4,12 +4,16 @@ use crate::{
|
|||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolChoice,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
use http_client::Result;
|
||||
use parking_lot::Mutex;
|
||||
use smol::stream::StreamExt;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FakeLanguageModelProvider {
|
||||
|
@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
|
|||
>,
|
||||
)>,
|
||||
>,
|
||||
forbid_requests: AtomicBool,
|
||||
}
|
||||
|
||||
impl Default for FakeLanguageModel {
|
||||
|
@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
|
|||
provider_id: LanguageModelProviderId::from("fake".to_string()),
|
||||
provider_name: LanguageModelProviderName::from("Fake".to_string()),
|
||||
current_completion_txs: Mutex::new(Vec::new()),
|
||||
forbid_requests: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeLanguageModel {
|
||||
pub fn allow_requests(&self) {
|
||||
self.forbid_requests.store(false, SeqCst);
|
||||
}
|
||||
|
||||
pub fn forbid_requests(&self) {
|
||||
self.forbid_requests.store(true, SeqCst);
|
||||
}
|
||||
|
||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||
self.current_completion_txs
|
||||
.lock()
|
||||
|
@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
|
|||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move { Ok(rx.boxed()) }.boxed()
|
||||
if self.forbid_requests.load(SeqCst) {
|
||||
async move {
|
||||
Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"requests are forbidden"
|
||||
)))
|
||||
}
|
||||
.boxed()
|
||||
} else {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move { Ok(rx.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &Self {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue