diff --git a/Cargo.lock b/Cargo.lock index 64470b5abe..45684b8920 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,7 +20,9 @@ dependencies = [ "itertools 0.14.0", "language", "markdown", + "parking_lot", "project", + "rand 0.8.5", "serde", "serde_json", "settings", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index 011f26f364..cd7a5c3808 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -41,7 +41,9 @@ async-pipe.workspace = true env_logger.workspace = true gpui = { workspace = true, "features" = ["test-support"] } indoc.workspace = true +parking_lot.workspace = true project = { workspace = true, "features" = ["test-support"] } +rand.workspace = true tempfile.workspace = true util.workspace = true settings.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 7a10f3bd72..0996dee723 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -671,7 +671,18 @@ impl AcpThread { for entry in self.entries.iter().rev() { match entry { AgentThreadEntry::UserMessage(_) => return false, - AgentThreadEntry::ToolCall(call) if call.diffs().next().is_some() => return true, + AgentThreadEntry::ToolCall( + call @ ToolCall { + status: + ToolCallStatus::Allowed { + status: + acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending, + }, + .. + }, + ) if call.diffs().next().is_some() => { + return true; + } AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {} } } @@ -1231,10 +1242,15 @@ mod tests { use agentic_coding_protocol as acp_old; use anyhow::anyhow; use async_pipe::{PipeReader, PipeWriter}; - use futures::{channel::mpsc, future::LocalBoxFuture, select}; - use gpui::{AsyncApp, TestAppContext}; + use futures::{ + channel::mpsc, + future::{LocalBoxFuture, try_join_all}, + select, + }; + use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; use project::FakeFs; + use rand::Rng as _; use serde_json::json; use settings::SettingsStore; use smol::{future::BoxedLocal, stream::StreamExt as _}; @@ -1562,6 +1578,42 @@ mod tests { }); } + #[gpui::test] + async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + + let connection = Rc::new(StubAgentConnection::new(vec![ + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("test".into()), + label: "Label".into(), + kind: acp::ToolKind::Edit, + status: acp::ToolCallStatus::Completed, + content: vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: "/test/test.txt".into(), + old_text: None, + new_text: "foo".into(), + }, + }], + locations: vec![], + raw_input: None, + }), + ])); + + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) + .await + .unwrap(); + + assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1589,6 +1641,96 @@ mod tests { } } + #[derive(Clone, Default)] + struct StubAgentConnection { + sessions: Arc>>>, + permission_requests: HashMap>, + updates: Vec, + } + + impl StubAgentConnection { + fn new(updates: Vec) -> Self { + Self { + updates, + permission_requests: HashMap::default(), + sessions: Arc::default(), + } + } + } + + impl AgentConnection for StubAgentConnection { + fn name(&self) -> &'static str { + "StubAgentConnection" + } + + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::AsyncApp, + ) -> Task>> { + let session_id = acp::SessionId( + rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(7) + .map(char::from) + .collect::() + .into(), + ); + let thread = cx + .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx)) + .unwrap(); + self.sessions.lock().insert(session_id, thread.downgrade()); + Task::ready(Ok(thread)) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + unimplemented!() + } + + fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task> { + let sessions = self.sessions.lock(); + let thread = sessions.get(¶ms.session_id).unwrap(); + let mut tasks = vec![]; + for update in &self.updates { + let thread = thread.clone(); + let update = update.clone(); + let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update + && let Some(options) = self.permission_requests.get(&tool_call.id) + { + Some((tool_call.clone(), options.clone())) + } else { + None + }; + let task = cx.spawn(async move |cx| { + if let Some((tool_call, options)) = permission_request { + let permission = thread.update(cx, |thread, cx| { + thread.request_tool_call_permission( + tool_call.clone(), + options.clone(), + cx, + ) + })?; + permission.await?; + } + thread.update(cx, |thread, cx| { + thread.handle_session_update(update.clone(), cx).unwrap(); + })?; + anyhow::Ok(()) + }); + tasks.push(task); + } + cx.spawn(async move |_| { + try_join_all(tasks).await?; + Ok(()) + }) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + } + pub fn fake_acp_thread( project: Entity, cx: &mut TestAppContext,