agent2: Always finalize diffs from the edit tool (#36918)
Previously, we wouldn't finalize the diff if an error occurred during editing or the tool call was canceled. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
c14d84cfdb
commit
b249593abe
3 changed files with 152 additions and 6 deletions
|
@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
|
||||||
|
let event = self.0.next().await;
|
||||||
|
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
|
||||||
|
update,
|
||||||
|
)))) = event
|
||||||
|
{
|
||||||
|
update.fields
|
||||||
|
} else {
|
||||||
|
panic!("Expected update fields but got: {:?}", event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
|
||||||
|
let event = self.0.next().await;
|
||||||
|
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
|
||||||
|
update,
|
||||||
|
)))) = event
|
||||||
|
{
|
||||||
|
update.diff
|
||||||
|
} else {
|
||||||
|
panic!("Expected diff but got: {:?}", event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
|
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
|
||||||
let event = self.0.next().await;
|
let event = self.0.next().await;
|
||||||
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
|
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
|
||||||
|
|
|
@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
|
||||||
|
|
||||||
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
|
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
|
||||||
event_stream.update_diff(diff.clone());
|
event_stream.update_diff(diff.clone());
|
||||||
|
let _finalize_diff = util::defer({
|
||||||
|
let diff = diff.downgrade();
|
||||||
|
let mut cx = cx.clone();
|
||||||
|
move || {
|
||||||
|
diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||||
let old_text = cx
|
let old_text = cx
|
||||||
|
@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
|
|
||||||
|
|
||||||
let input_path = input.path.display();
|
let input_path = input.path.display();
|
||||||
if unified_diff.is_empty() {
|
if unified_diff.is_empty() {
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
|
@ -1545,6 +1550,100 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_diff_finalization(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = project::FakeFs::new(cx.executor());
|
||||||
|
fs.insert_tree("/", json!({"main.rs": ""})).await;
|
||||||
|
|
||||||
|
let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
|
||||||
|
let languages = project.read_with(cx, |project, _cx| project.languages().clone());
|
||||||
|
let context_server_registry =
|
||||||
|
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||||
|
let model = Arc::new(FakeLanguageModel::default());
|
||||||
|
let thread = cx.new(|cx| {
|
||||||
|
Thread::new(
|
||||||
|
project.clone(),
|
||||||
|
cx.new(|_cx| ProjectContext::default()),
|
||||||
|
context_server_registry.clone(),
|
||||||
|
Templates::new(),
|
||||||
|
Some(model.clone()),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Ensure the diff is finalized after the edit completes.
|
||||||
|
{
|
||||||
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
|
||||||
|
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||||
|
let edit = cx.update(|cx| {
|
||||||
|
tool.run(
|
||||||
|
EditFileToolInput {
|
||||||
|
display_description: "Edit file".into(),
|
||||||
|
path: path!("/main.rs").into(),
|
||||||
|
mode: EditFileMode::Edit,
|
||||||
|
},
|
||||||
|
stream_tx,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
stream_rx.expect_update_fields().await;
|
||||||
|
let diff = stream_rx.expect_diff().await;
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
|
||||||
|
cx.run_until_parked();
|
||||||
|
model.end_last_completion_stream();
|
||||||
|
edit.await.unwrap();
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the diff is finalized if an error occurs while editing.
|
||||||
|
{
|
||||||
|
model.forbid_requests();
|
||||||
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
|
||||||
|
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||||
|
let edit = cx.update(|cx| {
|
||||||
|
tool.run(
|
||||||
|
EditFileToolInput {
|
||||||
|
display_description: "Edit file".into(),
|
||||||
|
path: path!("/main.rs").into(),
|
||||||
|
mode: EditFileMode::Edit,
|
||||||
|
},
|
||||||
|
stream_tx,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
stream_rx.expect_update_fields().await;
|
||||||
|
let diff = stream_rx.expect_diff().await;
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
|
||||||
|
edit.await.unwrap_err();
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
|
||||||
|
model.allow_requests();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the diff is finalized if the tool call gets dropped.
|
||||||
|
{
|
||||||
|
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
|
||||||
|
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||||
|
let edit = cx.update(|cx| {
|
||||||
|
tool.run(
|
||||||
|
EditFileToolInput {
|
||||||
|
display_description: "Edit file".into(),
|
||||||
|
path: path!("/main.rs").into(),
|
||||||
|
mode: EditFileMode::Edit,
|
||||||
|
},
|
||||||
|
stream_tx,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
stream_rx.expect_update_fields().await;
|
||||||
|
let diff = stream_rx.expect_diff().await;
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
|
||||||
|
drop(edit);
|
||||||
|
cx.run_until_parked();
|
||||||
|
diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
cx.update(|cx| {
|
cx.update(|cx| {
|
||||||
let settings_store = SettingsStore::test(cx);
|
let settings_store = SettingsStore::test(cx);
|
||||||
|
|
|
@ -4,12 +4,16 @@ use crate::{
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelRequest, LanguageModelToolChoice,
|
LanguageModelRequest, LanguageModelToolChoice,
|
||||||
};
|
};
|
||||||
|
use anyhow::anyhow;
|
||||||
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||||
use http_client::Result;
|
use http_client::Result;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::sync::Arc;
|
use std::sync::{
|
||||||
|
Arc,
|
||||||
|
atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FakeLanguageModelProvider {
|
pub struct FakeLanguageModelProvider {
|
||||||
|
@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
|
||||||
>,
|
>,
|
||||||
)>,
|
)>,
|
||||||
>,
|
>,
|
||||||
|
forbid_requests: AtomicBool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for FakeLanguageModel {
|
impl Default for FakeLanguageModel {
|
||||||
|
@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
|
||||||
provider_id: LanguageModelProviderId::from("fake".to_string()),
|
provider_id: LanguageModelProviderId::from("fake".to_string()),
|
||||||
provider_name: LanguageModelProviderName::from("Fake".to_string()),
|
provider_name: LanguageModelProviderName::from("Fake".to_string()),
|
||||||
current_completion_txs: Mutex::new(Vec::new()),
|
current_completion_txs: Mutex::new(Vec::new()),
|
||||||
|
forbid_requests: AtomicBool::new(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeLanguageModel {
|
impl FakeLanguageModel {
|
||||||
|
pub fn allow_requests(&self) {
|
||||||
|
self.forbid_requests.store(false, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forbid_requests(&self) {
|
||||||
|
self.forbid_requests.store(true, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
|
||||||
self.current_completion_txs
|
self.current_completion_txs
|
||||||
.lock()
|
.lock()
|
||||||
|
@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
LanguageModelCompletionError,
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let (tx, rx) = mpsc::unbounded();
|
if self.forbid_requests.load(SeqCst) {
|
||||||
self.current_completion_txs.lock().push((request, tx));
|
async move {
|
||||||
async move { Ok(rx.boxed()) }.boxed()
|
Err(LanguageModelCompletionError::Other(anyhow!(
|
||||||
|
"requests are forbidden"
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
self.current_completion_txs.lock().push((request, tx));
|
||||||
|
async move { Ok(rx.boxed()) }.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_fake(&self) -> &Self {
|
fn as_fake(&self) -> &Self {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue