diff --git a/Cargo.lock b/Cargo.lock index d31189fa06..9078b32f7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,7 @@ dependencies = [ "ui", "url", "util", + "uuid", "watch", "workspace-hack", ] @@ -6446,6 +6447,7 @@ dependencies = [ "log", "parking_lot", "pretty_assertions", + "rand 0.8.5", "regex", "rope", "schemars", diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index fd01b31786..b3ec217bad 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -36,6 +36,7 @@ terminal.workspace = true ui.workspace = true url.workspace = true util.workspace = true +uuid.workspace = true watch.workspace = true workspace-hack.workspace = true diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d1957e1c2a..f8a5bf8032 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -9,18 +9,19 @@ pub use mention::*; pub use terminal::*; use action_log::ActionLog; -use agent_client_protocol::{self as acp}; -use anyhow::{Context as _, Result}; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use itertools::Itertools; use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff}; use markdown::Markdown; -use project::{AgentLocation, Project}; +use project::{AgentLocation, Project, git_store::GitStoreCheckpoint}; use std::collections::HashMap; use std::error::Error; -use std::fmt::Formatter; +use std::fmt::{Formatter, Write}; +use std::ops::Range; use std::process::ExitStatus; use std::rc::Rc; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; @@ -29,24 +30,23 @@ use util::ResultExt; #[derive(Debug)] pub struct UserMessage { + pub id: Option, pub content: ContentBlock, + pub checkpoint: Option, } impl UserMessage { - pub fn from_acp( - message: impl IntoIterator, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let mut content = ContentBlock::Empty; - for chunk in message { - content.append(chunk, &language_registry, cx) - } - Self { content: content } - } - fn to_markdown(&self, cx: &App) -> String { - format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) + let mut markdown = String::new(); + if let Some(_) = self.checkpoint { + writeln!(markdown, "## User (checkpoint)").unwrap(); + } else { + writeln!(markdown, "## User").unwrap(); + } + writeln!(markdown).unwrap(); + writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap(); + writeln!(markdown).unwrap(); + markdown } } @@ -633,6 +633,7 @@ pub struct AcpThread { pub enum AcpThreadEvent { NewEntry, EntryUpdated(usize), + EntriesRemoved(Range), ToolAuthorizationRequired, Stopped, Error, @@ -772,7 +773,7 @@ impl AcpThread { ) -> Result<()> { match update { acp::SessionUpdate::UserMessageChunk { content } => { - self.push_user_content_block(content, cx); + self.push_user_content_block(None, content, cx); } acp::SessionUpdate::AgentMessageChunk { content } => { self.push_assistant_content_block(content, false, cx); @@ -793,18 +794,32 @@ impl AcpThread { Ok(()) } - pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context) { + pub fn push_user_content_block( + &mut self, + message_id: Option, + chunk: acp::ContentBlock, + cx: &mut Context, + ) { let language_registry = self.project.read(cx).languages().clone(); let entries_len = self.entries.len(); if let Some(last_entry) = self.entries.last_mut() - && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry + && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry { + *id = message_id.or(id.take()); content.append(chunk, &language_registry, cx); - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); } else { let content = ContentBlock::new(chunk, &language_registry, cx); - self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx); + self.push_entry( + AgentThreadEntry::UserMessage(UserMessage { + id: message_id, + content, + checkpoint: None, + }), + cx, + ); } } @@ -819,7 +834,8 @@ impl AcpThread { if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry { - cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1)); + let idx = entries_len - 1; + cx.emit(AcpThreadEvent::EntryUpdated(idx)); match (chunks.last_mut(), is_thought) { (Some(AssistantMessageChunk::Message { block }), false) | (Some(AssistantMessageChunk::Thought { block }), true) => { @@ -1118,69 +1134,113 @@ impl AcpThread { self.project.read(cx).languages().clone(), cx, ); + let git_store = self.project.read(cx).git_store().clone(); + + let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx)); + let message_id = if self + .connection + .session_editor(&self.session_id, cx) + .is_some() + { + Some(UserMessageId::new()) + } else { + None + }; self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { content: block }), + AgentThreadEntry::UserMessage(UserMessage { + id: message_id.clone(), + content: block, + checkpoint: None, + }), cx, ); self.clear_completed_plan_entries(cx); + let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel(); let (tx, rx) = oneshot::channel(); let cancel_task = self.cancel(cx); + let request = acp::PromptRequest { + prompt: message, + session_id: self.session_id.clone(), + }; - self.send_task = Some(cx.spawn(async move |this, cx| { - async { + self.send_task = Some(cx.spawn({ + let message_id = message_id.clone(); + async move |this, cx| { cancel_task.await; - let result = this - .update(cx, |this, cx| { - this.connection.prompt( - acp::PromptRequest { - prompt: message, - session_id: this.session_id.clone(), - }, - cx, - ) - })? - .await; - - tx.send(result).log_err(); - - anyhow::Ok(()) + old_checkpoint_tx.send(old_checkpoint.await).ok(); + if let Ok(result) = this.update(cx, |this, cx| { + this.connection.prompt(message_id, request, cx) + }) { + tx.send(result.await).log_err(); + } } - .await - .log_err(); })); - cx.spawn(async move |this, cx| match rx.await { - Ok(Err(e)) => { - this.update(cx, |this, cx| { - this.send_task.take(); - cx.emit(AcpThreadEvent::Error) - }) + cx.spawn(async move |this, cx| { + let old_checkpoint = old_checkpoint_rx + .await + .map_err(|_| anyhow!("send canceled")) + .flatten() + .context("failed to get old checkpoint") .log_err(); - Err(e)? - } - result => { - let cancelled = matches!( - result, - Ok(Ok(acp::PromptResponse { - stop_reason: acp::StopReason::Cancelled - })) - ); - // We only take the task if the current prompt wasn't cancelled. - // - // This prompt may have been cancelled because another one was sent - // while it was still generating. In these cases, dropping `send_task` - // would cause the next generation to be cancelled. - if !cancelled { - this.update(cx, |this, _cx| this.send_task.take()).ok(); - } + let response = rx.await; - this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped)) + if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) { + let new_checkpoint = git_store + .update(cx, |git, cx| git.checkpoint(cx))? + .await + .context("failed to get new checkpoint") .log_err(); - Ok(()) + if let Some(new_checkpoint) = new_checkpoint { + let equal = git_store + .update(cx, |git, cx| { + git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx) + })? + .await + .unwrap_or(true); + if !equal { + this.update(cx, |this, cx| { + if let Some((ix, message)) = this.user_message_mut(&message_id) { + message.checkpoint = Some(old_checkpoint); + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + } + })?; + } + } } + + this.update(cx, |this, cx| { + match response { + Ok(Err(e)) => { + this.send_task.take(); + cx.emit(AcpThreadEvent::Error); + Err(e) + } + result => { + let cancelled = matches!( + result, + Ok(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Cancelled + })) + ); + + // We only take the task if the current prompt wasn't cancelled. + // + // This prompt may have been cancelled because another one was sent + // while it was still generating. In these cases, dropping `send_task` + // would cause the next generation to be cancelled. + if !cancelled { + this.send_task.take(); + } + + cx.emit(AcpThreadEvent::Stopped); + Ok(()) + } + } + })? }) .boxed() } @@ -1212,6 +1272,66 @@ impl AcpThread { cx.foreground_executor().spawn(send_task) } + /// Rewinds this thread to before the entry at `index`, removing it and all + /// subsequent entries while reverting any changes made from that point. + pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { + let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else { + return Task::ready(Err(anyhow!("not supported"))); + }; + let Some(message) = self.user_message(&id) else { + return Task::ready(Err(anyhow!("message not found"))); + }; + + let checkpoint = message.checkpoint.clone(); + + let git_store = self.project.read(cx).git_store().clone(); + cx.spawn(async move |this, cx| { + if let Some(checkpoint) = checkpoint { + git_store + .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))? + .await?; + } + + cx.update(|cx| session_editor.truncate(id.clone(), cx))? + .await?; + this.update(cx, |this, cx| { + if let Some((ix, _)) = this.user_message_mut(&id) { + let range = ix..this.entries.len(); + this.entries.truncate(ix); + cx.emit(AcpThreadEvent::EntriesRemoved(range)); + } + }) + }) + } + + fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { + self.entries.iter().find_map(|entry| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some(message) + } else { + None + } + } else { + None + } + }) + } + + fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { + self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { + if let AgentThreadEntry::UserMessage(message) = entry { + if message.id.as_ref() == Some(&id) { + Some((ix, message)) + } else { + None + } + } else { + None + } + }) + } + pub fn read_text_file( &self, path: PathBuf, @@ -1414,13 +1534,18 @@ mod tests { use futures::{channel::mpsc, future::LocalBoxFuture, select}; use gpui::{AsyncApp, TestAppContext, WeakEntity}; use indoc::indoc; - use project::FakeFs; + use project::{FakeFs, Fs}; use rand::Rng as _; use serde_json::json; use settings::SettingsStore; use smol::stream::StreamExt as _; - use std::{cell::RefCell, path::Path, rc::Rc, time::Duration}; - + use std::{ + cell::RefCell, + path::Path, + rc::Rc, + sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst}, + time::Duration, + }; use util::path; fn init_test(cx: &mut TestAppContext) { @@ -1452,6 +1577,7 @@ mod tests { // Test creating a new user message thread.update(cx, |thread, cx| { thread.push_user_content_block( + None, acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "Hello, ".to_string(), @@ -1463,6 +1589,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, None); assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); } else { panic!("Expected UserMessage"); @@ -1470,8 +1597,10 @@ mod tests { }); // Test appending to existing user message + let message_1_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_1_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "world!".to_string(), @@ -1483,6 +1612,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { + assert_eq!(user_msg.id, Some(message_1_id)); assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); } else { panic!("Expected UserMessage"); @@ -1501,8 +1631,10 @@ mod tests { ); }); + let message_2_id = UserMessageId::new(); thread.update(cx, |thread, cx| { thread.push_user_content_block( + Some(message_2_id.clone()), acp::ContentBlock::Text(acp::TextContent { annotations: None, text: "New user message".to_string(), @@ -1514,6 +1646,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 3); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { + assert_eq!(user_msg.id, Some(message_2_id)); assert_eq!(user_msg.content.to_markdown(cx), "New user message"); } else { panic!("Expected UserMessage at index 2"); @@ -1830,6 +1963,180 @@ mod tests { assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); } + #[gpui::test(iterations = 10)] + async fn test_checkpoints(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/test"), + json!({ + ".git": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await; + + let simulate_changes = Arc::new(AtomicBool::new(true)); + let next_filename = Arc::new(AtomicUsize::new(0)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + let fs = fs.clone(); + move |request, thread, mut cx| { + let fs = fs.clone(); + let simulate_changes = simulate_changes.clone(); + let next_filename = next_filename.clone(); + async move { + if simulate_changes.load(SeqCst) { + let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst)); + fs.write(Path::new(&filename), b"").await?; + } + + let acp::ContentBlock::Text(content) = &request.prompt[0] else { + panic!("expected text content block"); + }; + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::AgentMessageChunk { + content: content.text.to_uppercase().into(), + }, + cx, + ) + .unwrap(); + })?; + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + })); + let thread = connection + .new_thread(project, Path::new(path!("/test")), &mut cx.to_async()) + .await + .unwrap(); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Checkpoint isn't stored when there are no changes. + simulate_changes.store(false, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx))) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + ## User (checkpoint) + + ipsum + + ## Assistant + + IPSUM + + ## User + + dolor + + ## Assistant + + DOLOR + + "} + ); + }); + assert_eq!( + fs.files(), + vec![Path::new("/test/file-0"), Path::new("/test/file-1")] + ); + + // Rewinding the conversation truncates the history and restores the checkpoint. + thread + .update(cx, |thread, cx| { + let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { + panic!("unexpected entries {:?}", thread.entries) + }; + thread.rewind(message.id.clone().unwrap(), cx) + }) + .await + .unwrap(); + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc! {" + ## User (checkpoint) + + Lorem + + ## Assistant + + LOREM + + "} + ); + }); + assert_eq!(fs.files(), vec![Path::new("/test/file-0")]); + } + async fn run_until_first_tool_call( thread: &Entity, cx: &mut TestAppContext, @@ -1938,6 +2245,7 @@ mod tests { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -1966,5 +2274,25 @@ mod tests { }) .detach(); } + + fn session_editor( + &self, + session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + Some(Rc::new(FakeAgentSessionEditor { + _session_id: session_id.clone(), + })) + } + } + + struct FakeAgentSessionEditor { + _session_id: acp::SessionId, + } + + impl AgentSessionEditor for FakeAgentSessionEditor { + fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 8e6294b3ce..c3167eb2d4 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,13 +1,21 @@ -use std::{error::Error, fmt, path::Path, rc::Rc}; - +use crate::AcpThread; use agent_client_protocol::{self as acp}; use anyhow::Result; use collections::IndexMap; use gpui::{AsyncApp, Entity, SharedString, Task}; use project::Project; +use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use ui::{App, IconName}; +use uuid::Uuid; -use crate::AcpThread; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UserMessageId(Arc); + +impl UserMessageId { + pub fn new() -> Self { + Self(Uuid::new_v4().to_string().into()) + } +} pub trait AgentConnection { fn new_thread( @@ -21,11 +29,23 @@ pub trait AgentConnection { fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - fn prompt(&self, params: acp::PromptRequest, cx: &mut App) - -> Task>; + fn prompt( + &self, + user_message_id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task>; fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); + fn session_editor( + &self, + _session_id: &acp::SessionId, + _cx: &mut App, + ) -> Option> { + None + } + /// Returns this agent as an [Rc] if the model selection capability is supported. /// /// If the agent does not support model selection, returns [None]. @@ -35,6 +55,10 @@ pub trait AgentConnection { } } +pub trait AgentSessionEditor { + fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +} + #[derive(Debug)] pub struct AuthRequired; diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 3ddd7be793..ced8c5e401 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,8 +1,9 @@ use crate::{AgentResponseEvent, Thread, templates::Templates}; use crate::{ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, - FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool, - OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool, + FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, + ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent, + WebSearchTool, }; use acp_thread::AgentModelSelector; use agent_client_protocol as acp; @@ -637,9 +638,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn prompt( &self, + id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { + let id = id.expect("UserMessageId is required"); let session_id = params.session_id.clone(); let agent = self.0.clone(); log::info!("Received prompt request for session: {}", session_id); @@ -660,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection { })?; log::debug!("Found session for: {}", session_id); - let message: Vec = params + let content: Vec = params .prompt .into_iter() .map(Into::into) .collect::>(); - log::info!("Converted prompt to message: {} chars", message.len()); - log::debug!("Message content: {:?}", message); + log::info!("Converted prompt to message: {} chars", content.len()); + log::debug!("Message id: {:?}", id); + log::debug!("Message content: {:?}", content); // Get model using the ModelSelector capability (always available for agent2) // Get the selected model from the thread directly @@ -674,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Send to thread log::info!("Sending message to thread with model: {:?}", model.name()); - let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?; + let mut response_stream = + thread.update(cx, |thread, cx| thread.send(id, content, cx))?; // Handle response stream and forward to session.acp_thread while let Some(result) = response_stream.next().await { @@ -768,6 +773,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection { } }); } + + fn session_editor( + &self, + session_id: &agent_client_protocol::SessionId, + cx: &mut App, + ) -> Option> { + self.0.update(cx, |agent, _cx| { + agent + .sessions + .get(session_id) + .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _) + }) + } +} + +struct NativeAgentSessionEditor(Entity); + +impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { + fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) + } } #[cfg(test)] diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index b70fa56747..637af73d1a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,6 +1,5 @@ use super::*; -use crate::MessageContent; -use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList}; +use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; use action_log::ActionLog; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; @@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { - thread.send("Testing: Reply with 'Hello'", cx) + thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) }) .collect() .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] - ); + thread.last_message().unwrap().to_markdown(), + indoc! {" + ## Assistant + + Hello + "} + ) }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } @@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.send( - indoc! {" + UserMessageId::new(), + [indoc! {" Testing: Generate a thinking step where you just think the word 'Think', and have your final answer be 'Hello' - "}, + "}], cx, ) }) @@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) { .await; thread.update(cx, |thread, _cx| { assert_eq!( - thread.messages().last().unwrap().to_markdown(), + thread.last_message().unwrap().to_markdown(), indoc! {" - ## assistant + ## Assistant + Think Hello "} @@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) { project_context.borrow_mut().shell = "test-shell".into(); thread.update(cx, |thread, _| thread.add_tool(EchoTool)); - thread.update(cx, |thread, cx| thread.send("abc", cx)); + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); assert_eq!( @@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(EchoTool); thread.send( - "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.", + UserMessageId::new(), + ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], cx, ) }) @@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.remove_tool(&AgentTool::name(&EchoTool)); thread.add_tool(DelayTool); thread.send( - "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.", + UserMessageId::new(), + [ + "Now call the delay tool with 200ms.", + "When the timer goes off, then you echo the output of the tool.", + ], cx, ) }) @@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| { assert!( thread - .messages() - .last() + .last_message() + .unwrap() + .as_agent_message() .unwrap() .content .iter() .any(|content| { - if let MessageContent::Text(text) = content { + if let AgentMessageContent::Text(text) = content { text.contains("Ding") } else { false @@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { // Test a tool call that's likely to complete *before* streaming stops. let mut events = thread.update(cx, |thread, cx| { thread.add_tool(WordListTool); - thread.send("Test the word_list tool.", cx) + thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) }); let mut saw_partial_tool_use = false; @@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message - let last_content = thread.messages().last().unwrap().content.last().unwrap(); - if let MessageContent::ToolUse(last_tool_use) = last_content { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); + let last_content = agent_message.content.last().unwrap(); + if let AgentMessageContent::ToolUse(last_tool_use) = last_content { assert_eq!(last_tool_use.name.as_ref(), "word_list"); if tool_call.status == acp::ToolCallStatus::Pending { if !last_tool_use.is_input_complete @@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread.update(cx, |thread, cx| { thread.add_tool(ToolRequiringPermission); - thread.send("abc", cx) + thread.send(UserMessageId::new(), ["abc"], cx) }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["abc"], cx) + }); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(DelayTool); thread.send( - "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.", + UserMessageId::new(), + [ + "Call the delay tool twice in the same message.", + "Once with 100ms. Once with 300ms.", + "When both timers are complete, describe the outputs.", + ], cx, ) }) @@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]); thread.update(cx, |thread, _cx| { - let last_message = thread.messages().last().unwrap(); - let text = last_message + let last_message = thread.last_message().unwrap(); + let agent_message = last_message.as_agent_message().unwrap(); + let text = agent_message .content .iter() .filter_map(|content| { - if let MessageContent::Text(text) = content { + if let AgentMessageContent::Text(text) = content { Some(text.as_str()) } else { None @@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Test that test-1 profile (default) has echo and delay tools thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-1".into())); - thread.send("test", cx); + thread.send(UserMessageId::new(), ["test"], cx); }); cx.run_until_parked(); @@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) { // Switch to test-2 profile, and verify that it has only the infinite tool. thread.update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-2".into())); - thread.send("test2", cx) + thread.send(UserMessageId::new(), ["test2"], cx) }); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); @@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) { thread.add_tool(InfiniteTool); thread.add_tool(EchoTool); thread.send( - "Call the echo tool and then call the infinite tool, then explain their output", + UserMessageId::new(), + ["Call the echo tool, then call the infinite tool, then explain their output"], cx, ) }); @@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Ensure we can still send a new message after cancellation. let events = thread .update(cx, |thread, cx| { - thread.send("Testing: reply with 'Hello' then stop.", cx) + thread.send( + UserMessageId::new(), + ["Testing: reply with 'Hello' then stop."], + cx, + ) }) .collect::>() .await; thread.update(cx, |thread, _cx| { + let message = thread.last_message().unwrap(); + let agent_message = message.as_agent_message().unwrap(); assert_eq!( - thread.messages().last().unwrap().content, - vec![MessageContent::Text("Hello".to_string())] + agent_message.content, + vec![AgentMessageContent::Text("Hello".to_string())] ); }); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); @@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let events = thread.update(cx, |thread, cx| thread.send("Hello", cx)); + let events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello "} ); @@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) { assert_eq!( thread.to_markdown(), indoc! {" - ## user + ## User + Hello - ## assistant + + ## Assistant + Hey! "} ); @@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_truncate(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let message_id = UserMessageId::new(); + thread.update(cx, |thread, cx| { + thread.send(message_id.clone(), ["Hello"], cx) + }); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + "} + ); + }); + + fake_model.send_last_completion_stream_text_chunk("Hey!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hello + + ## Assistant + + Hey! + "} + ); + }); + + thread + .update(cx, |thread, _cx| thread.truncate(message_id)) + .unwrap(); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!(thread.to_markdown(), ""); + }); + + // Ensure we can still send a new message after truncation. + thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hi"], cx) + }); + thread.update(cx, |thread, _cx| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + "} + ); + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Ahoy!"); + cx.run_until_parked(); + thread.read_with(cx, |thread, _| { + assert_eq!( + thread.to_markdown(), + indoc! {" + ## User + + Hi + + ## Assistant + + Ahoy! + "} + ); + }); +} + #[gpui::test] async fn test_agent_connection(cx: &mut TestAppContext) { cx.update(settings::init); @@ -774,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let result = cx .update(|cx| { connection.prompt( + Some(acp_thread::UserMessageId::new()), acp::PromptRequest { session_id: session_id.clone(), prompt: vec!["ghi".into()], @@ -796,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); let fake_model = model.as_fake(); - let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx)); + let mut events = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Think"], cx) + }); cx.run_until_parked(); // Simulate streaming partial input. diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 1a571e8009..204b489124 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,12 +1,12 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; -use acp_thread::MentionUri; +use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, AgentSettings}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; use cloud_llm_client::{CompletionIntent, CompletionMode}; -use collections::HashMap; +use collections::IndexMap; use fs::Fs; use futures::{ channel::{mpsc, oneshot}, @@ -19,7 +19,6 @@ use language_model::{ LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; -use log; use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; @@ -30,49 +29,199 @@ use std::fmt::Write; use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc}; use util::{ResultExt, markdown::MarkdownCodeBlock}; -#[derive(Debug, Clone)] -pub struct AgentMessage { - pub role: Role, - pub content: Vec, +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Message { + User(UserMessage), + Agent(AgentMessage), +} + +impl Message { + pub fn as_agent_message(&self) -> Option<&AgentMessage> { + match self { + Message::Agent(agent_message) => Some(agent_message), + _ => None, + } + } + + pub fn to_markdown(&self) -> String { + match self { + Message::User(message) => message.to_markdown(), + Message::Agent(message) => message.to_markdown(), + } + } } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum MessageContent { +pub struct UserMessage { + pub id: UserMessageId, + pub content: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UserMessageContent { Text(String), - Thinking { - text: String, - signature: Option, - }, - Mention { - uri: MentionUri, - content: String, - }, - RedactedThinking(String), + Mention { uri: MentionUri, content: String }, Image(LanguageModelImage), - ToolUse(LanguageModelToolUse), - ToolResult(LanguageModelToolResult), +} + +impl UserMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = String::from("## User\n\n"); + + for content in &self.content { + match content { + UserMessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + UserMessageContent::Image(_) => { + markdown.push_str("\n"); + } + UserMessageContent::Mention { uri, content } => { + if !content.is_empty() { + markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content)); + } else { + markdown.push_str(&format!("{}\n", uri.to_link())); + } + } + } + } + + markdown + } + + fn to_request(&self) -> LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + UserMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + UserMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + UserMessageContent::Mention { uri, content } => { + match uri { + MentionUri::File(path) | MentionUri::Symbol(path, _) => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(&path), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::Thread(_session_id) => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule(_user_prompt_id) => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: &content + } + ) + .ok(); + } + } + + language_model::MessageContent::Text(uri.to_link()) + } + }; + + message.content.push(chunk); + } + + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); + } + + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } + + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } + + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } + + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); + } + + message + } } impl AgentMessage { pub fn to_markdown(&self) -> String { - let mut markdown = format!("## {}\n", self.role); + let mut markdown = String::from("## Assistant\n\n"); for content in &self.content { match content { - MessageContent::Text(text) => { + AgentMessageContent::Text(text) => { markdown.push_str(text); markdown.push('\n'); } - MessageContent::Thinking { text, .. } => { + AgentMessageContent::Thinking { text, .. } => { markdown.push_str(""); markdown.push_str(text); markdown.push_str("\n"); } - MessageContent::RedactedThinking(_) => markdown.push_str("\n"), - MessageContent::Image(_) => { + AgentMessageContent::RedactedThinking(_) => { + markdown.push_str("\n") + } + AgentMessageContent::Image(_) => { markdown.push_str("\n"); } - MessageContent::ToolUse(tool_use) => { + AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", tool_use.name, tool_use.id @@ -85,41 +234,106 @@ impl AgentMessage { } )); } - MessageContent::ToolResult(tool_result) => { - markdown.push_str(&format!( - "**Tool Result**: {} (ID: {})\n\n", - tool_result.tool_name, tool_result.tool_use_id - )); - if tool_result.is_error { - markdown.push_str("**ERROR:**\n"); - } + } + } - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}\n").ok(); - } - LanguageModelToolResultContent::Image(_) => { - writeln!(markdown, "\n").ok(); - } - } + for tool_result in self.tool_results.values() { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "**Debug Output**:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output).unwrap() - ) - .unwrap(); - } + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); } - MessageContent::Mention { uri, .. } => { - write!(markdown, "{}", uri.to_link()).ok(); + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); } } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } } markdown } + + pub fn to_request(&self) -> Vec { + let mut content = Vec::with_capacity(self.content.len()); + for chunk in &self.content { + let chunk = match chunk { + AgentMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + AgentMessageContent::Thinking { text, signature } => { + language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + } + } + AgentMessageContent::RedactedThinking(value) => { + language_model::MessageContent::RedactedThinking(value.clone()) + } + AgentMessageContent::ToolUse(value) => { + language_model::MessageContent::ToolUse(value.clone()) + } + AgentMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + }; + content.push(chunk); + } + + let mut messages = vec![LanguageModelRequestMessage { + role: Role::Assistant, + content, + cache: false, + }]; + + if !self.tool_results.is_empty() { + let mut tool_results = Vec::with_capacity(self.tool_results.len()); + for tool_result in self.tool_results.values() { + tool_results.push(language_model::MessageContent::ToolResult( + tool_result.clone(), + )); + } + messages.push(LanguageModelRequestMessage { + role: Role::User, + content: tool_results, + cache: false, + }); + } + + messages + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct AgentMessage { + pub content: Vec, + pub tool_results: IndexMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentMessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + Image(LanguageModelImage), + ToolUse(LanguageModelToolUse), } #[derive(Debug)] @@ -140,13 +354,13 @@ pub struct ToolCallAuthorization { } pub struct Thread { - messages: Vec, + messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. running_turn: Option>, - pending_tool_uses: HashMap, + pending_agent_message: Option, tools: BTreeMap>, context_server_registry: Entity, profile_id: AgentProfileId, @@ -172,7 +386,7 @@ impl Thread { messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, - pending_tool_uses: HashMap::default(), + pending_agent_message: None, tools: BTreeMap::default(), context_server_registry, profile_id, @@ -196,8 +410,13 @@ impl Thread { self.completion_mode = mode; } - pub fn messages(&self) -> &[AgentMessage] { - &self.messages + #[cfg(any(test, feature = "test-support"))] + pub fn last_message(&self) -> Option { + if let Some(message) = self.pending_agent_message.clone() { + Some(Message::Agent(message)) + } else { + self.messages.last().cloned() + } } pub fn add_tool(&mut self, tool: impl AgentTool) { @@ -213,35 +432,36 @@ impl Thread { } pub fn cancel(&mut self) { + // TODO: do we need to emit a stop::cancel for ACP? self.running_turn.take(); + self.flush_pending_agent_message(); + } - let tool_results = self - .pending_tool_uses - .drain() - .map(|(tool_use_id, tool_use)| { - MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id, - tool_name: tool_use.name.clone(), - is_error: true, - content: LanguageModelToolResultContent::Text("Tool canceled by user".into()), - output: None, - }) - }) - .collect::>(); - self.last_user_message().content.extend(tool_results); + pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> { + self.cancel(); + let Some(position) = self.messages.iter().position( + |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), + ) else { + return Err(anyhow!("Message not found")); + }; + self.messages.truncate(position); + Ok(()) } /// Sending a message results in the model streaming a response, which could include tool calls. /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. - pub fn send( + pub fn send( &mut self, - content: impl Into, + message_id: UserMessageId, + content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { - let content = content.into().0; - + ) -> mpsc::UnboundedReceiver> + where + T: Into, + { let model = self.selected_model.clone(); + let content = content.into_iter().map(Into::into).collect::>(); log::info!("Thread::send called with model: {:?}", model.name()); log::debug!("Thread::send content: {:?}", content); @@ -251,10 +471,10 @@ impl Thread { let event_stream = AgentResponseEventStream(events_tx); let user_message_ix = self.messages.len(); - self.messages.push(AgentMessage { - role: Role::User, + self.messages.push(Message::User(UserMessage { + id: message_id, content, - }); + })); log::info!("Total messages in thread: {}", self.messages.len()); self.running_turn = Some(cx.spawn(async move |thread, cx| { log::info!("Starting agent turn execution"); @@ -270,15 +490,11 @@ impl Thread { thread.build_completion_request(completion_intent, cx) })?; - // println!( - // "request: {}", - // serde_json::to_string_pretty(&request).unwrap() - // ); - // Stream events, appending to messages and collecting up tool uses. log::info!("Calling model.stream_completion"); let mut events = model.stream_completion(request, cx).await?; log::debug!("Stream completion started successfully"); + let mut tool_uses = FuturesUnordered::new(); while let Some(event) = events.next().await { match event { @@ -286,6 +502,7 @@ impl Thread { event_stream.send_stop(reason); if reason == StopReason::Refusal { thread.update(cx, |thread, _cx| { + thread.pending_agent_message = None; thread.messages.truncate(user_message_ix); })?; break 'outer; @@ -338,15 +555,16 @@ impl Thread { ); thread .update(cx, |thread, _cx| { - thread.pending_tool_uses.remove(&tool_result.tool_use_id); thread - .last_user_message() - .content - .push(MessageContent::ToolResult(tool_result)); + .pending_agent_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); }) .ok(); } + thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?; + completion_intent = CompletionIntent::ToolResults; } @@ -354,6 +572,10 @@ impl Thread { } .await; + thread + .update(cx, |thread, _cx| thread.flush_pending_agent_message()) + .ok(); + if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error); event_stream.send_error(error); @@ -364,7 +586,7 @@ impl Thread { events_rx } - pub fn build_system_message(&self) -> AgentMessage { + pub fn build_system_message(&self) -> LanguageModelRequestMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { project: &self.project_context.borrow(), @@ -374,9 +596,10 @@ impl Thread { .context("failed to build system prompt") .expect("Invalid template"); log::debug!("System message built"); - AgentMessage { + LanguageModelRequestMessage { role: Role::System, - content: vec![prompt.as_str().into()], + content: vec![prompt.into()], + cache: true, } } @@ -394,10 +617,7 @@ impl Thread { match event { StartMessage { .. } => { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); + self.messages.push(Message::Agent(AgentMessage::default())); } Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Thinking { text, signature } => { @@ -435,11 +655,13 @@ impl Thread { ) { events_stream.send_text(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { + let last_message = self.pending_agent_message(); + if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { text.push_str(&new_text); } else { - last_message.content.push(MessageContent::Text(new_text)); + last_message + .content + .push(AgentMessageContent::Text(new_text)); } cx.notify(); @@ -454,13 +676,14 @@ impl Thread { ) { event_stream.send_thinking(&new_text); - let last_message = self.last_assistant_message(); - if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() + let last_message = self.pending_agent_message(); + if let Some(AgentMessageContent::Thinking { text, signature }) = + last_message.content.last_mut() { text.push_str(&new_text); *signature = new_signature.or(signature.take()); } else { - last_message.content.push(MessageContent::Thinking { + last_message.content.push(AgentMessageContent::Thinking { text: new_text, signature: new_signature, }); @@ -470,10 +693,10 @@ impl Thread { } fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { - let last_message = self.last_assistant_message(); + let last_message = self.pending_agent_message(); last_message .content - .push(MessageContent::RedactedThinking(data)); + .push(AgentMessageContent::RedactedThinking(data)); cx.notify(); } @@ -486,14 +709,17 @@ impl Thread { cx.notify(); let tool = self.tools.get(tool_use.name.as_ref()).cloned(); - - self.pending_tool_uses - .insert(tool_use.id.clone(), tool_use.clone()); - let last_message = self.last_assistant_message(); + let mut title = SharedString::from(&tool_use.name); + let mut kind = acp::ToolKind::Other; + if let Some(tool) = tool.as_ref() { + title = tool.initial_title(tool_use.input.clone()); + kind = tool.kind(); + } // Ensure the last message ends in the current tool use + let last_message = self.pending_agent_message(); let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| { - if let MessageContent::ToolUse(last_tool_use) = content { + if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); false @@ -505,18 +731,11 @@ impl Thread { } }); - let mut title = SharedString::from(&tool_use.name); - let mut kind = acp::ToolKind::Other; - if let Some(tool) = tool.as_ref() { - title = tool.initial_title(tool_use.input.clone()); - kind = tool.kind(); - } - if push_new_tool_use { event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); last_message .content - .push(MessageContent::ToolUse(tool_use.clone())); + .push(AgentMessageContent::ToolUse(tool_use.clone())); } else { event_stream.update_tool_call_fields( &tool_use.id, @@ -601,30 +820,37 @@ impl Thread { } } - /// Guarantees the last message is from the assistant and returns a mutable reference. - fn last_assistant_message(&mut self) -> &mut AgentMessage { - if self - .messages - .last() - .map_or(true, |m| m.role != Role::Assistant) - { - self.messages.push(AgentMessage { - role: Role::Assistant, - content: Vec::new(), - }); - } - self.messages.last_mut().unwrap() + fn pending_agent_message(&mut self) -> &mut AgentMessage { + self.pending_agent_message.get_or_insert_default() } - /// Guarantees the last message is from the user and returns a mutable reference. - fn last_user_message(&mut self) -> &mut AgentMessage { - if self.messages.last().map_or(true, |m| m.role != Role::User) { - self.messages.push(AgentMessage { - role: Role::User, - content: Vec::new(), - }); + fn flush_pending_agent_message(&mut self) { + let Some(mut message) = self.pending_agent_message.take() else { + return; + }; + + for content in &message.content { + let AgentMessageContent::ToolUse(tool_use) = content else { + continue; + }; + + if !message.tool_results.contains_key(&tool_use.id) { + message.tool_results.insert( + tool_use.id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text( + "Tool canceled by user".into(), + ), + output: None, + }, + ); + } } - self.messages.last_mut().unwrap() + + self.messages.push(Message::Agent(message)); } pub(crate) fn build_completion_request( @@ -712,49 +938,39 @@ impl Thread { "Building request messages from {} thread messages", self.messages.len() ); + let mut messages = vec![self.build_system_message()]; + for message in &self.messages { + match message { + Message::User(message) => messages.push(message.to_request()), + Message::Agent(message) => messages.extend(message.to_request()), + } + } + + if let Some(message) = self.pending_agent_message.as_ref() { + messages.extend(message.to_request()); + } - let messages = Some(self.build_system_message()) - .iter() - .chain(self.messages.iter()) - .map(|message| { - log::trace!( - " - {} message with {} content items", - match message.role { - Role::System => "System", - Role::User => "User", - Role::Assistant => "Assistant", - }, - message.content.len() - ); - message.to_request() - }) - .collect(); messages } pub fn to_markdown(&self) -> String { let mut markdown = String::new(); - for message in &self.messages { + for (ix, message) in self.messages.iter().enumerate() { + if ix > 0 { + markdown.push('\n'); + } markdown.push_str(&message.to_markdown()); } + + if let Some(message) = self.pending_agent_message.as_ref() { + markdown.push('\n'); + markdown.push_str(&message.to_markdown()); + } + markdown } } -pub struct UserMessage(Vec); - -impl From> for UserMessage { - fn from(content: Vec) -> Self { - UserMessage(content) - } -} - -impl> From for UserMessage { - fn from(content: T) -> Self { - UserMessage(vec![content.into()]) - } -} - pub trait AgentTool where Self: 'static + Sized, @@ -1151,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver { } } -impl AgentMessage { - fn to_request(&self) -> language_model::LanguageModelRequestMessage { - let mut message = LanguageModelRequestMessage { - role: self.role, - content: Vec::with_capacity(self.content.len()), - cache: false, - }; - - const OPEN_CONTEXT: &str = "\n\ - The following items were attached by the user. \ - They are up-to-date and don't need to be re-read.\n\n"; - - const OPEN_FILES_TAG: &str = ""; - const OPEN_SYMBOLS_TAG: &str = ""; - const OPEN_THREADS_TAG: &str = ""; - const OPEN_RULES_TAG: &str = - "\nThe user has specified the following rules that should be applied:\n"; - - let mut file_context = OPEN_FILES_TAG.to_string(); - let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); - let mut thread_context = OPEN_THREADS_TAG.to_string(); - let mut rules_context = OPEN_RULES_TAG.to_string(); - - for chunk in &self.content { - let chunk = match chunk { - MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()), - MessageContent::Thinking { text, signature } => { - language_model::MessageContent::Thinking { - text: text.clone(), - signature: signature.clone(), - } - } - MessageContent::RedactedThinking(value) => { - language_model::MessageContent::RedactedThinking(value.clone()) - } - MessageContent::ToolUse(value) => { - language_model::MessageContent::ToolUse(value.clone()) - } - MessageContent::ToolResult(value) => { - language_model::MessageContent::ToolResult(value.clone()) - } - MessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } - MessageContent::Mention { uri, content } => { - match uri { - MentionUri::File(path) | MentionUri::Symbol(path, _) => { - write!( - &mut symbol_context, - "\n{}", - MarkdownCodeBlock { - tag: &codeblock_tag(&path), - text: &content.to_string(), - } - ) - .ok(); - } - MentionUri::Thread(_session_id) => { - write!(&mut thread_context, "\n{}\n", content).ok(); - } - MentionUri::Rule(_user_prompt_id) => { - write!( - &mut rules_context, - "\n{}", - MarkdownCodeBlock { - tag: "", - text: &content - } - ) - .ok(); - } - } - - language_model::MessageContent::Text(uri.to_link()) - } - }; - - message.content.push(chunk); - } - - let len_before_context = message.content.len(); - - if file_context.len() > OPEN_FILES_TAG.len() { - file_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(file_context)); - } - - if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { - symbol_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(symbol_context)); - } - - if thread_context.len() > OPEN_THREADS_TAG.len() { - thread_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(thread_context)); - } - - if rules_context.len() > OPEN_RULES_TAG.len() { - rules_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(rules_context)); - } - - if message.content.len() > len_before_context { - message.content.insert( - len_before_context, - language_model::MessageContent::Text(OPEN_CONTEXT.into()), - ); - message - .content - .push(language_model::MessageContent::Text("".into())); - } - - message - } -} - fn codeblock_tag(full_path: &Path) -> String { let mut result = String::new(); @@ -1287,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String { result } -impl From for MessageContent { +impl From<&str> for UserMessageContent { + fn from(text: &str) -> Self { + Self::Text(text.into()) + } +} + +impl From for UserMessageContent { fn from(value: acp::ContentBlock) -> Self { match value { - acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text), - acp::ContentBlock::Image(image_content) => { - MessageContent::Image(convert_image(image_content)) - } + acp::ContentBlock::Text(text_content) => Self::Text(text_content.text), + acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)), acp::ContentBlock::Audio(_) => { // TODO - MessageContent::Text("[audio]".to_string()) + Self::Text("[audio]".to_string()) } acp::ContentBlock::ResourceLink(resource_link) => { match MentionUri::parse(&resource_link.uri) { @@ -1306,10 +1402,7 @@ impl From for MessageContent { }, Err(err) => { log::error!("Failed to parse mention link: {}", err); - MessageContent::Text(format!( - "[{}]({})", - resource_link.name, resource_link.uri - )) + Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri)) } } } @@ -1322,7 +1415,7 @@ impl From for MessageContent { }, Err(err) => { log::error!("Failed to parse mention link: {}", err); - MessageContent::Text( + Self::Text( MarkdownCodeBlock { tag: &resource.uri, text: &resource.text, @@ -1334,7 +1427,7 @@ impl From for MessageContent { } acp::EmbeddedResourceResource::BlobResourceContents(_) => { // TODO - MessageContent::Text("[blob]".to_string()) + Self::Text("[blob]".to_string()) } }, } @@ -1348,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { size: gpui::Size::new(0.into(), 0.into()), } } - -impl From<&str> for MessageContent { - fn from(text: &str) -> Self { - MessageContent::Text(text.into()) - } -} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 8d85435f92..327613de67 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index ff71783b48..de397fddf0 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index c65508f152..c394ec4a9c 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -423,7 +424,7 @@ impl ClaudeAgentSession { if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { - thread.push_user_content_block(text.into(), cx) + thread.push_user_content_block(None, text.into(), cx) }) .log_err(); } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 12fc29b08f..0b3ace1baf 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -679,17 +679,19 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let count = self.list_state.item_count(); match event { AcpThreadEvent::NewEntry => { let index = thread.read(cx).entries().len() - 1; self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(count..count, 1); + self.list_state.splice(index..index, 1); } AcpThreadEvent::EntryUpdated(index) => { - let index = *index; - self.sync_thread_entry_view(index, window, cx); - self.list_state.splice(index..index + 1, 1); + self.sync_thread_entry_view(*index, window, cx); + self.list_state.splice(*index..index + 1, 1); + } + AcpThreadEvent::EntriesRemoved(range) => { + // TODO: Clean up unused diff editors and terminal views + self.list_state.splice(range.clone(), 0); } AcpThreadEvent::ToolAuthorizationRequired => { self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx); @@ -3789,6 +3791,7 @@ mod tests { fn prompt( &self, + _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -3873,6 +3876,7 @@ mod tests { fn prompt( &self, + _id: Option, _params: acp::PromptRequest, _cx: &mut App, ) -> Task> { diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 0abc5280f4..b9e1ea5d0a 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1521,7 +1521,8 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::Stopped + AcpThreadEvent::EntriesRemoved(_) + | AcpThreadEvent::Stopped | AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {} diff --git a/crates/fs/Cargo.toml b/crates/fs/Cargo.toml index 633fc1fc99..1d4161134e 100644 --- a/crates/fs/Cargo.toml +++ b/crates/fs/Cargo.toml @@ -51,6 +51,7 @@ ashpd.workspace = true [dev-dependencies] gpui = { workspace = true, features = ["test-support"] } +git = { workspace = true, features = ["test-support"] } [features] test-support = ["gpui/test-support", "git/test-support"] diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 21b9cbca9a..f0936d400a 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -1,8 +1,9 @@ -use crate::{FakeFs, Fs}; +use crate::{FakeFs, FakeFsEntry, Fs}; use anyhow::{Context as _, Result}; use collections::{HashMap, HashSet}; use futures::future::{self, BoxFuture, join_all}; use git::{ + Oid, blame::Blame, repository::{ AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository, @@ -12,6 +13,7 @@ use git::{ }; use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task}; use ignore::gitignore::GitignoreBuilder; +use parking_lot::Mutex; use rope::Rope; use smol::future::FutureExt as _; use std::{path::PathBuf, sync::Arc}; @@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc}; #[derive(Clone)] pub struct FakeGitRepository { pub(crate) fs: Arc, + pub(crate) checkpoints: Arc>>, pub(crate) executor: BackgroundExecutor, pub(crate) dot_git_path: PathBuf, pub(crate) repository_dir_path: PathBuf, @@ -469,22 +472,57 @@ impl GitRepository for FakeGitRepository { } fn checkpoint(&self) -> BoxFuture<'static, Result> { - unimplemented!() + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let oid = Oid::random(&mut executor.rng()); + let entry = fs.entry(&repository_dir_path)?; + checkpoints.lock().insert(oid, entry); + Ok(GitRepositoryCheckpoint { commit_sha: oid }) + } + .boxed() } - fn restore_checkpoint( - &self, - _checkpoint: GitRepositoryCheckpoint, - ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> { + let executor = self.executor.clone(); + let fs = self.fs.clone(); + let checkpoints = self.checkpoints.clone(); + let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let entry = checkpoints + .get(&checkpoint.commit_sha) + .context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?; + fs.insert_entry(&repository_dir_path, entry.clone())?; + Ok(()) + } + .boxed() } fn compare_checkpoints( &self, - _left: GitRepositoryCheckpoint, - _right: GitRepositoryCheckpoint, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, ) -> BoxFuture<'_, Result> { - unimplemented!() + let executor = self.executor.clone(); + let checkpoints = self.checkpoints.clone(); + async move { + executor.simulate_random_delay().await; + let checkpoints = checkpoints.lock(); + let left = checkpoints + .get(&left.commit_sha) + .context(format!("invalid left checkpoint: {}", left.commit_sha))?; + let right = checkpoints + .get(&right.commit_sha) + .context(format!("invalid right checkpoint: {}", right.commit_sha))?; + + Ok(left == right) + } + .boxed() } fn diff_checkpoints( @@ -499,3 +537,63 @@ impl GitRepository for FakeGitRepository { unimplemented!() } } + +#[cfg(test)] +mod tests { + use crate::{FakeFs, Fs}; + use gpui::BackgroundExecutor; + use serde_json::json; + use std::path::Path; + use util::path; + + #[gpui::test] + async fn test_checkpoints(executor: BackgroundExecutor) { + let fs = FakeFs::new(executor); + fs.insert_tree( + path!("/"), + json!({ + "bar": { + "baz": "qux" + }, + "foo": { + ".git": {}, + "a": "lorem", + "b": "ipsum", + }, + }), + ) + .await; + fs.with_git_state(Path::new("/foo/.git"), true, |_git| {}) + .unwrap(); + let repository = fs.open_repo(Path::new("/foo/.git")).unwrap(); + + let checkpoint_1 = repository.checkpoint().await.unwrap(); + fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap(); + fs.write(Path::new("/foo/c"), b"dolor").await.unwrap(); + let checkpoint_2 = repository.checkpoint().await.unwrap(); + let checkpoint_3 = repository.checkpoint().await.unwrap(); + + assert!( + repository + .compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone()) + .await + .unwrap() + ); + assert!( + !repository + .compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone()) + .await + .unwrap() + ); + + repository.restore_checkpoint(checkpoint_1).await.unwrap(); + assert_eq!( + fs.files_with_contents(Path::new("")), + [ + (Path::new("/bar/baz").into(), b"qux".into()), + (Path::new("/foo/a").into(), b"lorem".into()), + (Path::new("/foo/b").into(), b"ipsum".into()) + ] + ); + } +} diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index a2b75ac6a7..22bfdbcd66 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -924,7 +924,7 @@ pub struct FakeFs { #[cfg(any(test, feature = "test-support"))] struct FakeFsState { - root: Arc>, + root: FakeFsEntry, next_inode: u64, next_mtime: SystemTime, git_event_tx: smol::channel::Sender, @@ -939,7 +939,7 @@ struct FakeFsState { } #[cfg(any(test, feature = "test-support"))] -#[derive(Debug)] +#[derive(Clone, Debug)] enum FakeFsEntry { File { inode: u64, @@ -953,7 +953,7 @@ enum FakeFsEntry { inode: u64, mtime: MTime, len: u64, - entries: BTreeMap>>, + entries: BTreeMap, git_repo_state: Option>>, }, Symlink { @@ -961,6 +961,67 @@ enum FakeFsEntry { }, } +#[cfg(any(test, feature = "test-support"))] +impl PartialEq for FakeFsEntry { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::File { + inode: l_inode, + mtime: l_mtime, + len: l_len, + content: l_content, + git_dir_path: l_git_dir_path, + }, + Self::File { + inode: r_inode, + mtime: r_mtime, + len: r_len, + content: r_content, + git_dir_path: r_git_dir_path, + }, + ) => { + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_content == r_content + && l_git_dir_path == r_git_dir_path + } + ( + Self::Dir { + inode: l_inode, + mtime: l_mtime, + len: l_len, + entries: l_entries, + git_repo_state: l_git_repo_state, + }, + Self::Dir { + inode: r_inode, + mtime: r_mtime, + len: r_len, + entries: r_entries, + git_repo_state: r_git_repo_state, + }, + ) => { + let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) { + (Some(l), Some(r)) => Arc::ptr_eq(l, r), + (None, None) => true, + _ => false, + }; + l_inode == r_inode + && l_mtime == r_mtime + && l_len == r_len + && l_entries == r_entries + && same_repo_state + } + (Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => { + l_target == r_target + } + _ => false, + } + } +} + #[cfg(any(test, feature = "test-support"))] impl FakeFsState { fn get_and_increment_mtime(&mut self) -> MTime { @@ -975,25 +1036,9 @@ impl FakeFsState { inode } - fn read_path(&self, target: &Path) -> Result>> { - Ok(self - .try_read_path(target, true) - .ok_or_else(|| { - anyhow!(io::Error::new( - io::ErrorKind::NotFound, - format!("not found: {target:?}") - )) - })? - .0) - } - - fn try_read_path( - &self, - target: &Path, - follow_symlink: bool, - ) -> Option<(Arc>, PathBuf)> { - let mut path = target.to_path_buf(); + fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option { let mut canonical_path = PathBuf::new(); + let mut path = target.to_path_buf(); let mut entry_stack = Vec::new(); 'outer: loop { let mut path_components = path.components().peekable(); @@ -1003,7 +1048,7 @@ impl FakeFsState { Component::Prefix(prefix_component) => prefix = Some(prefix_component), Component::RootDir => { entry_stack.clear(); - entry_stack.push(self.root.clone()); + entry_stack.push(&self.root); canonical_path.clear(); match prefix { Some(prefix_component) => { @@ -1020,20 +1065,18 @@ impl FakeFsState { canonical_path.pop(); } Component::Normal(name) => { - let current_entry = entry_stack.last().cloned()?; - let current_entry = current_entry.lock(); - if let FakeFsEntry::Dir { entries, .. } = &*current_entry { - let entry = entries.get(name.to_str().unwrap()).cloned()?; + let current_entry = *entry_stack.last()?; + if let FakeFsEntry::Dir { entries, .. } = current_entry { + let entry = entries.get(name.to_str().unwrap())?; if path_components.peek().is_some() || follow_symlink { - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target, .. } = &*entry { + if let FakeFsEntry::Symlink { target, .. } = entry { let mut target = target.clone(); target.extend(path_components); path = target; continue 'outer; } } - entry_stack.push(entry.clone()); + entry_stack.push(entry); canonical_path = canonical_path.join(name); } else { return None; @@ -1043,19 +1086,72 @@ impl FakeFsState { } break; } - Some((entry_stack.pop()?, canonical_path)) + + if entry_stack.is_empty() { + None + } else { + Some(canonical_path) + } } - fn write_path(&self, path: &Path, callback: Fn) -> Result + fn try_entry( + &mut self, + target: &Path, + follow_symlink: bool, + ) -> Option<(&mut FakeFsEntry, PathBuf)> { + let canonical_path = self.canonicalize(target, follow_symlink)?; + + let mut components = canonical_path.components(); + let Some(Component::RootDir) = components.next() else { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + }; + + let mut entry = &mut self.root; + for component in components { + match component { + Component::Normal(name) => { + if let FakeFsEntry::Dir { entries, .. } = entry { + entry = entries.get_mut(name.to_str().unwrap())?; + } else { + return None; + } + } + _ => { + panic!( + "the path {:?} was not canonicalized properly {:?}", + target, canonical_path + ) + } + } + } + + Some((entry, canonical_path)) + } + + fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> { + Ok(self + .try_entry(target, true) + .ok_or_else(|| { + anyhow!(io::Error::new( + io::ErrorKind::NotFound, + format!("not found: {target:?}") + )) + })? + .0) + } + + fn write_path(&mut self, path: &Path, callback: Fn) -> Result where - Fn: FnOnce(btree_map::Entry>>) -> Result, + Fn: FnOnce(btree_map::Entry) -> Result, { let path = normalize_path(path); let filename = path.file_name().context("cannot overwrite the root")?; let parent_path = path.parent().unwrap(); - let parent = self.read_path(parent_path)?; - let mut parent = parent.lock(); + let parent = self.entry(parent_path)?; let new_entry = parent .dir_entries(parent_path)? .entry(filename.to_str().unwrap().into()); @@ -1105,13 +1201,13 @@ impl FakeFs { this: this.clone(), executor: executor.clone(), state: Arc::new(Mutex::new(FakeFsState { - root: Arc::new(Mutex::new(FakeFsEntry::Dir { + root: FakeFsEntry::Dir { inode: 0, mtime: MTime(UNIX_EPOCH), len: 0, entries: Default::default(), git_repo_state: None, - })), + }, git_event_tx: tx, next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL, next_inode: 1, @@ -1161,15 +1257,15 @@ impl FakeFs { .write_path(path, move |entry| { match entry { btree_map::Entry::Vacant(e) => { - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, content: Vec::new(), len: 0, git_dir_path: None, - }))); + }); } - btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() { + btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() { FakeFsEntry::File { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime, FakeFsEntry::Symlink { .. } => {} @@ -1188,7 +1284,7 @@ impl FakeFs { pub async fn insert_symlink(&self, path: impl AsRef, target: PathBuf) { let mut state = self.state.lock(); let path = path.as_ref(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -1221,13 +1317,13 @@ impl FakeFs { match entry { btree_map::Entry::Vacant(e) => { kind = Some(PathEventKind::Created); - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode: new_inode, mtime: new_mtime, len: new_len, content: new_content, git_dir_path: None, - }))); + }); } btree_map::Entry::Occupied(mut e) => { kind = Some(PathEventKind::Changed); @@ -1237,7 +1333,7 @@ impl FakeFs { len, content, .. - } = &mut *e.get_mut().lock() + } = e.get_mut() { *mtime = new_mtime; *content = new_content; @@ -1259,9 +1355,8 @@ impl FakeFs { pub fn read_file_sync(&self, path: impl AsRef) -> Result> { let path = path.as_ref(); let path = normalize_path(path); - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1269,9 +1364,8 @@ impl FakeFs { let path = path.as_ref(); let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); + let mut state = self.state.lock(); + let entry = state.entry(&path)?; entry.file_content(&path).cloned() } @@ -1292,6 +1386,25 @@ impl FakeFs { self.state.lock().flush_events(count); } + pub(crate) fn entry(&self, target: &Path) -> Result { + self.state.lock().entry(target).cloned() + } + + pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> { + let mut state = self.state.lock(); + state.write_path(target, |entry| { + match entry { + btree_map::Entry::Vacant(vacant_entry) => { + vacant_entry.insert(new_entry); + } + btree_map::Entry::Occupied(mut occupied_entry) => { + occupied_entry.insert(new_entry); + } + } + Ok(()) + }) + } + #[must_use] pub fn insert_tree<'a>( &'a self, @@ -1361,20 +1474,19 @@ impl FakeFs { F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T, { let mut state = self.state.lock(); - let entry = state.read_path(dot_git).context("open .git")?; - let mut entry = entry.lock(); + let git_event_tx = state.git_event_tx.clone(); + let entry = state.entry(dot_git).context("open .git")?; - if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry { + if let FakeFsEntry::Dir { git_repo_state, .. } = entry { let repo_state = git_repo_state.get_or_insert_with(|| { log::debug!("insert git state for {dot_git:?}"); - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, dot_git, dot_git); + drop(repo_state); if emit_git_event { state.emit_event([(dot_git, None)]); } @@ -1398,21 +1510,20 @@ impl FakeFs { } } .clone(); - drop(entry); - let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else { + let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else { anyhow::bail!("pointed-to git dir {path:?} not found") }; let FakeFsEntry::Dir { git_repo_state, entries, .. - } = &mut *git_dir_entry.lock() + } = git_dir_entry else { anyhow::bail!("gitfile points to a non-directory") }; let common_dir = if let Some(child) = entries.get("commondir") { Path::new( - std::str::from_utf8(child.lock().file_content("commondir".as_ref())?) + std::str::from_utf8(child.file_content("commondir".as_ref())?) .context("commondir content")?, ) .to_owned() @@ -1420,15 +1531,14 @@ impl FakeFs { canonical_path.clone() }; let repo_state = git_repo_state.get_or_insert_with(|| { - Arc::new(Mutex::new(FakeGitRepositoryState::new( - state.git_event_tx.clone(), - ))) + Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx))) }); let mut repo_state = repo_state.lock(); let result = f(&mut repo_state, &canonical_path, &common_dir); if emit_git_event { + drop(repo_state); state.emit_event([(canonical_path, None)]); } @@ -1655,14 +1765,12 @@ impl FakeFs { pub fn paths(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } if include_dot_git @@ -1679,14 +1787,12 @@ impl FakeFs { pub fn directories(&self, include_dot_git: bool) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + if let FakeFsEntry::Dir { entries, .. } = entry { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } if include_dot_git || !path @@ -1703,17 +1809,14 @@ impl FakeFs { pub fn files(&self) -> Vec { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { .. } => result.push(path), FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1725,13 +1828,10 @@ impl FakeFs { pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec)> { let mut result = Vec::new(); let mut queue = collections::VecDeque::new(); - queue.push_back(( - PathBuf::from(util::path!("/")), - self.state.lock().root.clone(), - )); + let state = &*self.state.lock(); + queue.push_back((PathBuf::from(util::path!("/")), &state.root)); while let Some((path, entry)) = queue.pop_front() { - let e = entry.lock(); - match &*e { + match entry { FakeFsEntry::File { content, .. } => { if path.starts_with(prefix) { result.push((path, content.clone())); @@ -1739,7 +1839,7 @@ impl FakeFs { } FakeFsEntry::Dir { entries, .. } => { for (name, entry) in entries { - queue.push_back((path.join(name), entry.clone())); + queue.push_back((path.join(name), entry)); } } FakeFsEntry::Symlink { .. } => {} @@ -1805,10 +1905,7 @@ impl FakeFsEntry { } } - fn dir_entries( - &mut self, - path: &Path, - ) -> Result<&mut BTreeMap>>> { + fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap> { if let Self::Dir { entries, .. } = self { Ok(entries) } else { @@ -1855,12 +1952,12 @@ struct FakeHandle { impl FileHandle for FakeHandle { fn current_path(&self, fs: &Arc) -> Result { let fs = fs.as_fake(); - let state = fs.state.lock(); - let Some(target) = state.moves.get(&self.inode) else { + let mut state = fs.state.lock(); + let Some(target) = state.moves.get(&self.inode).cloned() else { anyhow::bail!("fake fd not moved") }; - if state.try_read_path(&target, false).is_some() { + if state.try_entry(&target, false).is_some() { return Ok(target.clone()); } anyhow::bail!("fake fd target not found") @@ -1888,13 +1985,13 @@ impl Fs for FakeFs { state.write_path(&cur_path, |entry| { entry.or_insert_with(|| { created_dirs.push((cur_path.clone(), Some(PathEventKind::Created))); - Arc::new(Mutex::new(FakeFsEntry::Dir { + FakeFsEntry::Dir { inode, mtime, len: 0, entries: Default::default(), git_repo_state: None, - })) + } }); Ok(()) })? @@ -1909,13 +2006,13 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let inode = state.get_and_increment_inode(); let mtime = state.get_and_increment_mtime(); - let file = Arc::new(Mutex::new(FakeFsEntry::File { + let file = FakeFsEntry::File { inode, mtime, len: 0, content: Vec::new(), git_dir_path: None, - })); + }; let mut kind = Some(PathEventKind::Created); state.write_path(path, |entry| { match entry { @@ -1939,7 +2036,7 @@ impl Fs for FakeFs { async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> { let mut state = self.state.lock(); - let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + let file = FakeFsEntry::Symlink { target }; state .write_path(path.as_ref(), move |e| match e { btree_map::Entry::Vacant(e) => { @@ -2002,7 +2099,7 @@ impl Fs for FakeFs { } })?; - let inode = match *moved_entry.lock() { + let inode = match moved_entry { FakeFsEntry::File { inode, .. } => inode, FakeFsEntry::Dir { inode, .. } => inode, _ => 0, @@ -2051,8 +2148,8 @@ impl Fs for FakeFs { let mut state = self.state.lock(); let mtime = state.get_and_increment_mtime(); let inode = state.get_and_increment_inode(); - let source_entry = state.read_path(&source)?; - let content = source_entry.lock().file_content(&source)?.clone(); + let source_entry = state.entry(&source)?; + let content = source_entry.file_content(&source)?.clone(); let mut kind = Some(PathEventKind::Created); state.write_path(&target, |e| match e { btree_map::Entry::Occupied(e) => { @@ -2066,13 +2163,13 @@ impl Fs for FakeFs { } } btree_map::Entry::Vacant(e) => Ok(Some( - e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + e.insert(FakeFsEntry::File { inode, mtime, len: content.len() as u64, content, git_dir_path: None, - }))) + }) .clone(), )), })?; @@ -2088,8 +2185,7 @@ impl Fs for FakeFs { let base_name = path.file_name().context("cannot remove the root")?; let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2100,15 +2196,14 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { + btree_map::Entry::Occupied(mut entry) => { { - let mut entry = e.get().lock(); - let children = entry.dir_entries(&path)?; + let children = entry.get_mut().dir_entries(&path)?; if !options.recursive && !children.is_empty() { anyhow::bail!("{path:?} is not empty"); } } - e.remove(); + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2122,8 +2217,7 @@ impl Fs for FakeFs { let parent_path = path.parent().context("cannot remove the root")?; let base_name = path.file_name().unwrap(); let mut state = self.state.lock(); - let parent_entry = state.read_path(parent_path)?; - let mut parent_entry = parent_entry.lock(); + let parent_entry = state.entry(parent_path)?; let entry = parent_entry .dir_entries(parent_path)? .entry(base_name.to_str().unwrap().into()); @@ -2133,9 +2227,9 @@ impl Fs for FakeFs { anyhow::bail!("{path:?} does not exist"); } } - btree_map::Entry::Occupied(e) => { - e.get().lock().file_content(&path)?; - e.remove(); + btree_map::Entry::Occupied(mut entry) => { + entry.get_mut().file_content(&path)?; + entry.remove(); } } state.emit_event([(path, Some(PathEventKind::Removed))]); @@ -2149,12 +2243,10 @@ impl Fs for FakeFs { async fn open_handle(&self, path: &Path) -> Result> { self.simulate_random_delay().await; - let state = self.state.lock(); - let entry = state.read_path(&path)?; - let entry = entry.lock(); - let inode = match *entry { - FakeFsEntry::File { inode, .. } => inode, - FakeFsEntry::Dir { inode, .. } => inode, + let mut state = self.state.lock(); + let inode = match state.entry(&path)? { + FakeFsEntry::File { inode, .. } => *inode, + FakeFsEntry::Dir { inode, .. } => *inode, _ => unreachable!(), }; Ok(Arc::new(FakeHandle { inode })) @@ -2204,8 +2296,8 @@ impl Fs for FakeFs { let path = normalize_path(path); self.simulate_random_delay().await; let state = self.state.lock(); - let (_, canonical_path) = state - .try_read_path(&path, true) + let canonical_path = state + .canonicalize(&path, true) .with_context(|| format!("path does not exist: {path:?}"))?; Ok(canonical_path) } @@ -2213,9 +2305,9 @@ impl Fs for FakeFs { async fn is_file(&self, path: &Path) -> bool { let path = normalize_path(path); self.simulate_random_delay().await; - let state = self.state.lock(); - if let Some((entry, _)) = state.try_read_path(&path, true) { - entry.lock().is_file() + let mut state = self.state.lock(); + if let Some((entry, _)) = state.try_entry(&path, true) { + entry.is_file() } else { false } @@ -2232,17 +2324,16 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.metadata_call_count += 1; - if let Some((mut entry, _)) = state.try_read_path(&path, false) { - let is_symlink = entry.lock().is_symlink(); + if let Some((mut entry, _)) = state.try_entry(&path, false) { + let is_symlink = entry.is_symlink(); if is_symlink { - if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) { + if let Some(e) = state.try_entry(&path, true).map(|e| e.0) { entry = e; } else { return Ok(None); } } - let entry = entry.lock(); Ok(Some(match &*entry { FakeFsEntry::File { inode, mtime, len, .. @@ -2274,12 +2365,11 @@ impl Fs for FakeFs { async fn read_link(&self, path: &Path) -> Result { self.simulate_random_delay().await; let path = normalize_path(path); - let state = self.state.lock(); + let mut state = self.state.lock(); let (entry, _) = state - .try_read_path(&path, false) + .try_entry(&path, false) .with_context(|| format!("path does not exist: {path:?}"))?; - let entry = entry.lock(); - if let FakeFsEntry::Symlink { target } = &*entry { + if let FakeFsEntry::Symlink { target } = entry { Ok(target.clone()) } else { anyhow::bail!("not a symlink: {path:?}") @@ -2294,8 +2384,7 @@ impl Fs for FakeFs { let path = normalize_path(path); let mut state = self.state.lock(); state.read_dir_call_count += 1; - let entry = state.read_path(&path)?; - let mut entry = entry.lock(); + let entry = state.entry(&path)?; let children = entry.dir_entries(&path)?; let paths = children .keys() @@ -2359,6 +2448,7 @@ impl Fs for FakeFs { dot_git_path: abs_dot_git.to_path_buf(), repository_dir_path: repository_dir_path.to_owned(), common_dir_path: common_dir_path.to_owned(), + checkpoints: Arc::default(), }) as _ }, ) diff --git a/crates/git/Cargo.toml b/crates/git/Cargo.toml index ab2210094d..74656f1d4c 100644 --- a/crates/git/Cargo.toml +++ b/crates/git/Cargo.toml @@ -12,7 +12,7 @@ workspace = true path = "src/git.rs" [features] -test-support = [] +test-support = ["rand"] [dependencies] anyhow.workspace = true @@ -26,6 +26,7 @@ http_client.workspace = true log.workspace = true parking_lot.workspace = true regex.workspace = true +rand = { workspace = true, optional = true } rope.workspace = true schemars.workspace = true serde.workspace = true @@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] } unindent.workspace = true gpui = { workspace = true, features = ["test-support"] } tempfile.workspace = true +rand.workspace = true diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index e6336eb656..e84014129c 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -119,6 +119,13 @@ impl Oid { Ok(Self(oid)) } + #[cfg(any(test, feature = "test-support"))] + pub fn random(rng: &mut impl rand::Rng) -> Self { + let mut bytes = [0; 20]; + rng.fill(&mut bytes); + Self::from_bytes(&bytes).unwrap() + } + pub fn as_bytes(&self) -> &[u8] { self.0.as_bytes() }