agent2: Initial infra for checkpoints and message editing (#36120)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f4b0332f78
commit
23cd5b59b2
17 changed files with 1374 additions and 582 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -31,6 +31,7 @@ dependencies = [
|
||||||
"ui",
|
"ui",
|
||||||
"url",
|
"url",
|
||||||
"util",
|
"util",
|
||||||
|
"uuid",
|
||||||
"watch",
|
"watch",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
|
@ -6446,6 +6447,7 @@ dependencies = [
|
||||||
"log",
|
"log",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"pretty_assertions",
|
"pretty_assertions",
|
||||||
|
"rand 0.8.5",
|
||||||
"regex",
|
"regex",
|
||||||
"rope",
|
"rope",
|
||||||
"schemars",
|
"schemars",
|
||||||
|
|
|
@ -36,6 +36,7 @@ terminal.workspace = true
|
||||||
ui.workspace = true
|
ui.workspace = true
|
||||||
url.workspace = true
|
url.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
|
uuid.workspace = true
|
||||||
watch.workspace = true
|
watch.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
|
|
|
@ -9,18 +9,19 @@ pub use mention::*;
|
||||||
pub use terminal::*;
|
pub use terminal::*;
|
||||||
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol as acp;
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use editor::Bias;
|
use editor::Bias;
|
||||||
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
||||||
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
|
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
|
||||||
use markdown::Markdown;
|
use markdown::Markdown;
|
||||||
use project::{AgentLocation, Project};
|
use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt::Formatter;
|
use std::fmt::{Formatter, Write};
|
||||||
|
use std::ops::Range;
|
||||||
use std::process::ExitStatus;
|
use std::process::ExitStatus;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||||
|
@ -29,24 +30,23 @@ use util::ResultExt;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct UserMessage {
|
pub struct UserMessage {
|
||||||
|
pub id: Option<UserMessageId>,
|
||||||
pub content: ContentBlock,
|
pub content: ContentBlock,
|
||||||
|
pub checkpoint: Option<GitStoreCheckpoint>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserMessage {
|
impl UserMessage {
|
||||||
pub fn from_acp(
|
|
||||||
message: impl IntoIterator<Item = acp::ContentBlock>,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Self {
|
|
||||||
let mut content = ContentBlock::Empty;
|
|
||||||
for chunk in message {
|
|
||||||
content.append(chunk, &language_registry, cx)
|
|
||||||
}
|
|
||||||
Self { content: content }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_markdown(&self, cx: &App) -> String {
|
fn to_markdown(&self, cx: &App) -> String {
|
||||||
format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
|
let mut markdown = String::new();
|
||||||
|
if let Some(_) = self.checkpoint {
|
||||||
|
writeln!(markdown, "## User (checkpoint)").unwrap();
|
||||||
|
} else {
|
||||||
|
writeln!(markdown, "## User").unwrap();
|
||||||
|
}
|
||||||
|
writeln!(markdown).unwrap();
|
||||||
|
writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
|
||||||
|
writeln!(markdown).unwrap();
|
||||||
|
markdown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,6 +633,7 @@ pub struct AcpThread {
|
||||||
pub enum AcpThreadEvent {
|
pub enum AcpThreadEvent {
|
||||||
NewEntry,
|
NewEntry,
|
||||||
EntryUpdated(usize),
|
EntryUpdated(usize),
|
||||||
|
EntriesRemoved(Range<usize>),
|
||||||
ToolAuthorizationRequired,
|
ToolAuthorizationRequired,
|
||||||
Stopped,
|
Stopped,
|
||||||
Error,
|
Error,
|
||||||
|
@ -772,7 +773,7 @@ impl AcpThread {
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
match update {
|
match update {
|
||||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||||
self.push_user_content_block(content, cx);
|
self.push_user_content_block(None, content, cx);
|
||||||
}
|
}
|
||||||
acp::SessionUpdate::AgentMessageChunk { content } => {
|
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||||
self.push_assistant_content_block(content, false, cx);
|
self.push_assistant_content_block(content, false, cx);
|
||||||
|
@ -793,18 +794,32 @@ impl AcpThread {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
|
pub fn push_user_content_block(
|
||||||
|
&mut self,
|
||||||
|
message_id: Option<UserMessageId>,
|
||||||
|
chunk: acp::ContentBlock,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
let language_registry = self.project.read(cx).languages().clone();
|
let language_registry = self.project.read(cx).languages().clone();
|
||||||
let entries_len = self.entries.len();
|
let entries_len = self.entries.len();
|
||||||
|
|
||||||
if let Some(last_entry) = self.entries.last_mut()
|
if let Some(last_entry) = self.entries.last_mut()
|
||||||
&& let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
|
&& let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
|
||||||
{
|
{
|
||||||
|
*id = message_id.or(id.take());
|
||||||
content.append(chunk, &language_registry, cx);
|
content.append(chunk, &language_registry, cx);
|
||||||
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
let idx = entries_len - 1;
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(idx));
|
||||||
} else {
|
} else {
|
||||||
let content = ContentBlock::new(chunk, &language_registry, cx);
|
let content = ContentBlock::new(chunk, &language_registry, cx);
|
||||||
self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
|
self.push_entry(
|
||||||
|
AgentThreadEntry::UserMessage(UserMessage {
|
||||||
|
id: message_id,
|
||||||
|
content,
|
||||||
|
checkpoint: None,
|
||||||
|
}),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -819,7 +834,8 @@ impl AcpThread {
|
||||||
if let Some(last_entry) = self.entries.last_mut()
|
if let Some(last_entry) = self.entries.last_mut()
|
||||||
&& let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
|
&& let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
|
||||||
{
|
{
|
||||||
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
let idx = entries_len - 1;
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(idx));
|
||||||
match (chunks.last_mut(), is_thought) {
|
match (chunks.last_mut(), is_thought) {
|
||||||
(Some(AssistantMessageChunk::Message { block }), false)
|
(Some(AssistantMessageChunk::Message { block }), false)
|
||||||
| (Some(AssistantMessageChunk::Thought { block }), true) => {
|
| (Some(AssistantMessageChunk::Thought { block }), true) => {
|
||||||
|
@ -1118,69 +1134,113 @@ impl AcpThread {
|
||||||
self.project.read(cx).languages().clone(),
|
self.project.read(cx).languages().clone(),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
|
|
||||||
|
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||||
|
let message_id = if self
|
||||||
|
.connection
|
||||||
|
.session_editor(&self.session_id, cx)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
Some(UserMessageId::new())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
self.push_entry(
|
self.push_entry(
|
||||||
AgentThreadEntry::UserMessage(UserMessage { content: block }),
|
AgentThreadEntry::UserMessage(UserMessage {
|
||||||
|
id: message_id.clone(),
|
||||||
|
content: block,
|
||||||
|
checkpoint: None,
|
||||||
|
}),
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
self.clear_completed_plan_entries(cx);
|
self.clear_completed_plan_entries(cx);
|
||||||
|
|
||||||
|
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
let cancel_task = self.cancel(cx);
|
let cancel_task = self.cancel(cx);
|
||||||
|
let request = acp::PromptRequest {
|
||||||
|
prompt: message,
|
||||||
|
session_id: self.session_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
self.send_task = Some(cx.spawn({
|
||||||
async {
|
let message_id = message_id.clone();
|
||||||
|
async move |this, cx| {
|
||||||
cancel_task.await;
|
cancel_task.await;
|
||||||
|
|
||||||
let result = this
|
old_checkpoint_tx.send(old_checkpoint.await).ok();
|
||||||
.update(cx, |this, cx| {
|
if let Ok(result) = this.update(cx, |this, cx| {
|
||||||
this.connection.prompt(
|
this.connection.prompt(message_id, request, cx)
|
||||||
acp::PromptRequest {
|
}) {
|
||||||
prompt: message,
|
tx.send(result.await).log_err();
|
||||||
session_id: this.session_id.clone(),
|
}
|
||||||
},
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.await;
|
|
||||||
|
|
||||||
tx.send(result).log_err();
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
}
|
||||||
.await
|
|
||||||
.log_err();
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
cx.spawn(async move |this, cx| match rx.await {
|
cx.spawn(async move |this, cx| {
|
||||||
Ok(Err(e)) => {
|
let old_checkpoint = old_checkpoint_rx
|
||||||
this.update(cx, |this, cx| {
|
.await
|
||||||
this.send_task.take();
|
.map_err(|_| anyhow!("send canceled"))
|
||||||
cx.emit(AcpThreadEvent::Error)
|
.flatten()
|
||||||
})
|
.context("failed to get old checkpoint")
|
||||||
.log_err();
|
.log_err();
|
||||||
Err(e)?
|
|
||||||
}
|
|
||||||
result => {
|
|
||||||
let cancelled = matches!(
|
|
||||||
result,
|
|
||||||
Ok(Ok(acp::PromptResponse {
|
|
||||||
stop_reason: acp::StopReason::Cancelled
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
|
|
||||||
// We only take the task if the current prompt wasn't cancelled.
|
let response = rx.await;
|
||||||
//
|
|
||||||
// This prompt may have been cancelled because another one was sent
|
|
||||||
// while it was still generating. In these cases, dropping `send_task`
|
|
||||||
// would cause the next generation to be cancelled.
|
|
||||||
if !cancelled {
|
|
||||||
this.update(cx, |this, _cx| this.send_task.take()).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
|
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
|
||||||
|
let new_checkpoint = git_store
|
||||||
|
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||||
|
.await
|
||||||
|
.context("failed to get new checkpoint")
|
||||||
.log_err();
|
.log_err();
|
||||||
Ok(())
|
if let Some(new_checkpoint) = new_checkpoint {
|
||||||
|
let equal = git_store
|
||||||
|
.update(cx, |git, cx| {
|
||||||
|
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||||
|
})?
|
||||||
|
.await
|
||||||
|
.unwrap_or(true);
|
||||||
|
if !equal {
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
if let Some((ix, message)) = this.user_message_mut(&message_id) {
|
||||||
|
message.checkpoint = Some(old_checkpoint);
|
||||||
|
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
match response {
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
this.send_task.take();
|
||||||
|
cx.emit(AcpThreadEvent::Error);
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
result => {
|
||||||
|
let cancelled = matches!(
|
||||||
|
result,
|
||||||
|
Ok(Ok(acp::PromptResponse {
|
||||||
|
stop_reason: acp::StopReason::Cancelled
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
|
||||||
|
// We only take the task if the current prompt wasn't cancelled.
|
||||||
|
//
|
||||||
|
// This prompt may have been cancelled because another one was sent
|
||||||
|
// while it was still generating. In these cases, dropping `send_task`
|
||||||
|
// would cause the next generation to be cancelled.
|
||||||
|
if !cancelled {
|
||||||
|
this.send_task.take();
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.emit(AcpThreadEvent::Stopped);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})?
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
@ -1212,6 +1272,66 @@ impl AcpThread {
|
||||||
cx.foreground_executor().spawn(send_task)
|
cx.foreground_executor().spawn(send_task)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Rewinds this thread to before the entry at `index`, removing it and all
|
||||||
|
/// subsequent entries while reverting any changes made from that point.
|
||||||
|
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||||
|
let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
|
||||||
|
return Task::ready(Err(anyhow!("not supported")));
|
||||||
|
};
|
||||||
|
let Some(message) = self.user_message(&id) else {
|
||||||
|
return Task::ready(Err(anyhow!("message not found")));
|
||||||
|
};
|
||||||
|
|
||||||
|
let checkpoint = message.checkpoint.clone();
|
||||||
|
|
||||||
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
|
cx.spawn(async move |this, cx| {
|
||||||
|
if let Some(checkpoint) = checkpoint {
|
||||||
|
git_store
|
||||||
|
.update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
cx.update(|cx| session_editor.truncate(id.clone(), cx))?
|
||||||
|
.await?;
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
if let Some((ix, _)) = this.user_message_mut(&id) {
|
||||||
|
let range = ix..this.entries.len();
|
||||||
|
this.entries.truncate(ix);
|
||||||
|
cx.emit(AcpThreadEvent::EntriesRemoved(range));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||||
|
self.entries.iter().find_map(|entry| {
|
||||||
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
|
if message.id.as_ref() == Some(&id) {
|
||||||
|
Some(message)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
|
||||||
|
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
|
||||||
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
|
if message.id.as_ref() == Some(&id) {
|
||||||
|
Some((ix, message))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn read_text_file(
|
pub fn read_text_file(
|
||||||
&self,
|
&self,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
|
@ -1414,13 +1534,18 @@ mod tests {
|
||||||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use project::FakeFs;
|
use project::{FakeFs, Fs};
|
||||||
use rand::Rng as _;
|
use rand::Rng as _;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::stream::StreamExt as _;
|
use smol::stream::StreamExt as _;
|
||||||
use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
|
use std::{
|
||||||
|
cell::RefCell,
|
||||||
|
path::Path,
|
||||||
|
rc::Rc,
|
||||||
|
sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
@ -1452,6 +1577,7 @@ mod tests {
|
||||||
// Test creating a new user message
|
// Test creating a new user message
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(
|
thread.push_user_content_block(
|
||||||
|
None,
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
annotations: None,
|
annotations: None,
|
||||||
text: "Hello, ".to_string(),
|
text: "Hello, ".to_string(),
|
||||||
|
@ -1463,6 +1589,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 1);
|
assert_eq!(thread.entries.len(), 1);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||||
|
assert_eq!(user_msg.id, None);
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage");
|
panic!("Expected UserMessage");
|
||||||
|
@ -1470,8 +1597,10 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Test appending to existing user message
|
// Test appending to existing user message
|
||||||
|
let message_1_id = UserMessageId::new();
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(
|
thread.push_user_content_block(
|
||||||
|
Some(message_1_id.clone()),
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
annotations: None,
|
annotations: None,
|
||||||
text: "world!".to_string(),
|
text: "world!".to_string(),
|
||||||
|
@ -1483,6 +1612,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 1);
|
assert_eq!(thread.entries.len(), 1);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||||
|
assert_eq!(user_msg.id, Some(message_1_id));
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage");
|
panic!("Expected UserMessage");
|
||||||
|
@ -1501,8 +1631,10 @@ mod tests {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let message_2_id = UserMessageId::new();
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(
|
thread.push_user_content_block(
|
||||||
|
Some(message_2_id.clone()),
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
annotations: None,
|
annotations: None,
|
||||||
text: "New user message".to_string(),
|
text: "New user message".to_string(),
|
||||||
|
@ -1514,6 +1646,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 3);
|
assert_eq!(thread.entries.len(), 3);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
||||||
|
assert_eq!(user_msg.id, Some(message_2_id));
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage at index 2");
|
panic!("Expected UserMessage at index 2");
|
||||||
|
@ -1830,6 +1963,180 @@ mod tests {
|
||||||
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
async fn test_checkpoints(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let fs = FakeFs::new(cx.background_executor.clone());
|
||||||
|
fs.insert_tree(
|
||||||
|
path!("/test"),
|
||||||
|
json!({
|
||||||
|
".git": {}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||||
|
|
||||||
|
let simulate_changes = Arc::new(AtomicBool::new(true));
|
||||||
|
let next_filename = Arc::new(AtomicUsize::new(0));
|
||||||
|
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
|
||||||
|
let simulate_changes = simulate_changes.clone();
|
||||||
|
let next_filename = next_filename.clone();
|
||||||
|
let fs = fs.clone();
|
||||||
|
move |request, thread, mut cx| {
|
||||||
|
let fs = fs.clone();
|
||||||
|
let simulate_changes = simulate_changes.clone();
|
||||||
|
let next_filename = next_filename.clone();
|
||||||
|
async move {
|
||||||
|
if simulate_changes.load(SeqCst) {
|
||||||
|
let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
|
||||||
|
fs.write(Path::new(&filename), b"").await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let acp::ContentBlock::Text(content) = &request.prompt[0] else {
|
||||||
|
panic!("expected text content block");
|
||||||
|
};
|
||||||
|
thread.update(&mut cx, |thread, cx| {
|
||||||
|
thread
|
||||||
|
.handle_session_update(
|
||||||
|
acp::SessionUpdate::AgentMessageChunk {
|
||||||
|
content: content.text.to_uppercase().into(),
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
})?;
|
||||||
|
Ok(acp::PromptResponse {
|
||||||
|
stop_reason: acp::StopReason::EndTurn,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
let thread = connection
|
||||||
|
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
Lorem
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
LOREM
|
||||||
|
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
|
||||||
|
|
||||||
|
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
Lorem
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
LOREM
|
||||||
|
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
ipsum
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
IPSUM
|
||||||
|
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
fs.files(),
|
||||||
|
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Checkpoint isn't stored when there are no changes.
|
||||||
|
simulate_changes.store(false, SeqCst);
|
||||||
|
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
Lorem
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
LOREM
|
||||||
|
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
ipsum
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
IPSUM
|
||||||
|
|
||||||
|
## User
|
||||||
|
|
||||||
|
dolor
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
DOLOR
|
||||||
|
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
assert_eq!(
|
||||||
|
fs.files(),
|
||||||
|
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Rewinding the conversation truncates the history and restores the checkpoint.
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, cx| {
|
||||||
|
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
|
||||||
|
panic!("unexpected entries {:?}", thread.entries)
|
||||||
|
};
|
||||||
|
thread.rewind(message.id.clone().unwrap(), cx)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
thread.read_with(cx, |thread, cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(cx),
|
||||||
|
indoc! {"
|
||||||
|
## User (checkpoint)
|
||||||
|
|
||||||
|
Lorem
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
LOREM
|
||||||
|
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
|
||||||
|
}
|
||||||
|
|
||||||
async fn run_until_first_tool_call(
|
async fn run_until_first_tool_call(
|
||||||
thread: &Entity<AcpThread>,
|
thread: &Entity<AcpThread>,
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
|
@ -1938,6 +2245,7 @@ mod tests {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
@ -1966,5 +2274,25 @@ mod tests {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn session_editor(
|
||||||
|
&self,
|
||||||
|
session_id: &acp::SessionId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||||
|
Some(Rc::new(FakeAgentSessionEditor {
|
||||||
|
_session_id: session_id.clone(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FakeAgentSessionEditor {
|
||||||
|
_session_id: acp::SessionId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentSessionEditor for FakeAgentSessionEditor {
|
||||||
|
fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
|
||||||
|
Task::ready(Ok(()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,21 @@
|
||||||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
use crate::AcpThread;
|
||||||
|
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use collections::IndexMap;
|
use collections::IndexMap;
|
||||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::AcpThread;
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
|
pub struct UserMessageId(Arc<str>);
|
||||||
|
|
||||||
|
impl UserMessageId {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self(Uuid::new_v4().to_string().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait AgentConnection {
|
pub trait AgentConnection {
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
|
@ -21,11 +29,23 @@ pub trait AgentConnection {
|
||||||
|
|
||||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
fn prompt(
|
||||||
-> Task<Result<acp::PromptResponse>>;
|
&self,
|
||||||
|
user_message_id: Option<UserMessageId>,
|
||||||
|
params: acp::PromptRequest,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Task<Result<acp::PromptResponse>>;
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||||
|
|
||||||
|
fn session_editor(
|
||||||
|
&self,
|
||||||
|
_session_id: &acp::SessionId,
|
||||||
|
_cx: &mut App,
|
||||||
|
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||||
///
|
///
|
||||||
/// If the agent does not support model selection, returns [None].
|
/// If the agent does not support model selection, returns [None].
|
||||||
|
@ -35,6 +55,10 @@ pub trait AgentConnection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait AgentSessionEditor {
|
||||||
|
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AuthRequired;
|
pub struct AuthRequired;
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||||
use crate::{
|
use crate::{
|
||||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
|
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||||
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
|
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||||
|
WebSearchTool,
|
||||||
};
|
};
|
||||||
use acp_thread::AgentModelSelector;
|
use acp_thread::AgentModelSelector;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -637,9 +638,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
id: Option<acp_thread::UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
let id = id.expect("UserMessageId is required");
|
||||||
let session_id = params.session_id.clone();
|
let session_id = params.session_id.clone();
|
||||||
let agent = self.0.clone();
|
let agent = self.0.clone();
|
||||||
log::info!("Received prompt request for session: {}", session_id);
|
log::info!("Received prompt request for session: {}", session_id);
|
||||||
|
@ -660,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
})?;
|
})?;
|
||||||
log::debug!("Found session for: {}", session_id);
|
log::debug!("Found session for: {}", session_id);
|
||||||
|
|
||||||
let message: Vec<MessageContent> = params
|
let content: Vec<UserMessageContent> = params
|
||||||
.prompt
|
.prompt
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(Into::into)
|
.map(Into::into)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
log::info!("Converted prompt to message: {} chars", message.len());
|
log::info!("Converted prompt to message: {} chars", content.len());
|
||||||
log::debug!("Message content: {:?}", message);
|
log::debug!("Message id: {:?}", id);
|
||||||
|
log::debug!("Message content: {:?}", content);
|
||||||
|
|
||||||
// Get model using the ModelSelector capability (always available for agent2)
|
// Get model using the ModelSelector capability (always available for agent2)
|
||||||
// Get the selected model from the thread directly
|
// Get the selected model from the thread directly
|
||||||
|
@ -674,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
|
|
||||||
// Send to thread
|
// Send to thread
|
||||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||||
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
|
let mut response_stream =
|
||||||
|
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||||
|
|
||||||
// Handle response stream and forward to session.acp_thread
|
// Handle response stream and forward to session.acp_thread
|
||||||
while let Some(result) = response_stream.next().await {
|
while let Some(result) = response_stream.next().await {
|
||||||
|
@ -768,6 +773,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn session_editor(
|
||||||
|
&self,
|
||||||
|
session_id: &agent_client_protocol::SessionId,
|
||||||
|
cx: &mut App,
|
||||||
|
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
||||||
|
self.0.update(cx, |agent, _cx| {
|
||||||
|
agent
|
||||||
|
.sessions
|
||||||
|
.get(session_id)
|
||||||
|
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||||
|
|
||||||
|
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||||
|
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::MessageContent;
|
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
|
||||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
|
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use agent_settings::AgentProfileId;
|
use agent_settings::AgentProfileId;
|
||||||
|
@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send("Testing: Reply with 'Hello'", cx)
|
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.messages().last().unwrap().content,
|
thread.last_message().unwrap().to_markdown(),
|
||||||
vec![MessageContent::Text("Hello".to_string())]
|
indoc! {"
|
||||||
);
|
## Assistant
|
||||||
|
|
||||||
|
Hello
|
||||||
|
"}
|
||||||
|
)
|
||||||
});
|
});
|
||||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||||
}
|
}
|
||||||
|
@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
indoc! {"
|
UserMessageId::new(),
|
||||||
|
[indoc! {"
|
||||||
Testing:
|
Testing:
|
||||||
|
|
||||||
Generate a thinking step where you just think the word 'Think',
|
Generate a thinking step where you just think the word 'Think',
|
||||||
and have your final answer be 'Hello'
|
and have your final answer be 'Hello'
|
||||||
"},
|
"}],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.messages().last().unwrap().to_markdown(),
|
thread.last_message().unwrap().to_markdown(),
|
||||||
indoc! {"
|
indoc! {"
|
||||||
## assistant
|
## Assistant
|
||||||
|
|
||||||
<think>Think</think>
|
<think>Think</think>
|
||||||
Hello
|
Hello
|
||||||
"}
|
"}
|
||||||
|
@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
project_context.borrow_mut().shell = "test-shell".into();
|
project_context.borrow_mut().shell = "test-shell".into();
|
||||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
thread.update(cx, |thread, cx| thread.send("abc", cx));
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
UserMessageId::new(),
|
||||||
|
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||||
thread.add_tool(DelayTool);
|
thread.add_tool(DelayTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
UserMessageId::new(),
|
||||||
|
[
|
||||||
|
"Now call the delay tool with 200ms.",
|
||||||
|
"When the timer goes off, then you echo the output of the tool.",
|
||||||
|
],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
assert!(
|
assert!(
|
||||||
thread
|
thread
|
||||||
.messages()
|
.last_message()
|
||||||
.last()
|
.unwrap()
|
||||||
|
.as_agent_message()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.content
|
.content
|
||||||
.iter()
|
.iter()
|
||||||
.any(|content| {
|
.any(|content| {
|
||||||
if let MessageContent::Text(text) = content {
|
if let AgentMessageContent::Text(text) = content {
|
||||||
text.contains("Ding")
|
text.contains("Ding")
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
|
@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
// Test a tool call that's likely to complete *before* streaming stops.
|
// Test a tool call that's likely to complete *before* streaming stops.
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread.update(cx, |thread, cx| {
|
||||||
thread.add_tool(WordListTool);
|
thread.add_tool(WordListTool);
|
||||||
thread.send("Test the word_list tool.", cx)
|
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut saw_partial_tool_use = false;
|
let mut saw_partial_tool_use = false;
|
||||||
|
@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
// Look for a tool use in the thread's last message
|
// Look for a tool use in the thread's last message
|
||||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
let message = thread.last_message().unwrap();
|
||||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
let agent_message = message.as_agent_message().unwrap();
|
||||||
|
let last_content = agent_message.content.last().unwrap();
|
||||||
|
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
|
||||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||||
if tool_call.status == acp::ToolCallStatus::Pending {
|
if tool_call.status == acp::ToolCallStatus::Pending {
|
||||||
if !last_tool_use.is_input_complete
|
if !last_tool_use.is_input_complete
|
||||||
|
@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| {
|
let mut events = thread.update(cx, |thread, cx| {
|
||||||
thread.add_tool(ToolRequiringPermission);
|
thread.add_tool(ToolRequiringPermission);
|
||||||
thread.send("abc", cx)
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
|
let mut events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||||
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
LanguageModelToolUse {
|
LanguageModelToolUse {
|
||||||
|
@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(DelayTool);
|
thread.add_tool(DelayTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
UserMessageId::new(),
|
||||||
|
[
|
||||||
|
"Call the delay tool twice in the same message.",
|
||||||
|
"Once with 100ms. Once with 300ms.",
|
||||||
|
"When both timers are complete, describe the outputs.",
|
||||||
|
],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||||
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
||||||
|
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
let last_message = thread.messages().last().unwrap();
|
let last_message = thread.last_message().unwrap();
|
||||||
let text = last_message
|
let agent_message = last_message.as_agent_message().unwrap();
|
||||||
|
let text = agent_message
|
||||||
.content
|
.content
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|content| {
|
.filter_map(|content| {
|
||||||
if let MessageContent::Text(text) = content {
|
if let AgentMessageContent::Text(text) = content {
|
||||||
Some(text.as_str())
|
Some(text.as_str())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
// Test that test-1 profile (default) has echo and delay tools
|
// Test that test-1 profile (default) has echo and delay tools
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-1".into()));
|
thread.set_profile(AgentProfileId("test-1".into()));
|
||||||
thread.send("test", cx);
|
thread.send(UserMessageId::new(), ["test"], cx);
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-2".into()));
|
thread.set_profile(AgentProfileId("test-2".into()));
|
||||||
thread.send("test2", cx)
|
thread.send(UserMessageId::new(), ["test2"], cx)
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
|
@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
thread.add_tool(InfiniteTool);
|
thread.add_tool(InfiniteTool);
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
"Call the echo tool and then call the infinite tool, then explain their output",
|
UserMessageId::new(),
|
||||||
|
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
// Ensure we can still send a new message after cancellation.
|
// Ensure we can still send a new message after cancellation.
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send("Testing: reply with 'Hello' then stop.", cx)
|
thread.send(
|
||||||
|
UserMessageId::new(),
|
||||||
|
["Testing: reply with 'Hello' then stop."],
|
||||||
|
cx,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.await;
|
.await;
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
|
let message = thread.last_message().unwrap();
|
||||||
|
let agent_message = message.as_agent_message().unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.messages().last().unwrap().content,
|
agent_message.content,
|
||||||
vec![MessageContent::Text("Hello".to_string())]
|
vec![AgentMessageContent::Text("Hello".to_string())]
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||||
|
@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
|
let events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||||
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(),
|
thread.to_markdown(),
|
||||||
indoc! {"
|
indoc! {"
|
||||||
## user
|
## User
|
||||||
|
|
||||||
Hello
|
Hello
|
||||||
"}
|
"}
|
||||||
);
|
);
|
||||||
|
@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(),
|
thread.to_markdown(),
|
||||||
indoc! {"
|
indoc! {"
|
||||||
## user
|
## User
|
||||||
|
|
||||||
Hello
|
Hello
|
||||||
## assistant
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
Hey!
|
Hey!
|
||||||
"}
|
"}
|
||||||
);
|
);
|
||||||
|
@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_truncate(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let message_id = UserMessageId::new();
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(message_id.clone(), ["Hello"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hello
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hello
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Hey!
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||||
|
.unwrap();
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(thread.to_markdown(), "");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Ensure we can still send a new message after truncation.
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hi"], cx)
|
||||||
|
});
|
||||||
|
thread.update(cx, |thread, _cx| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hi
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||||
|
cx.run_until_parked();
|
||||||
|
thread.read_with(cx, |thread, _| {
|
||||||
|
assert_eq!(
|
||||||
|
thread.to_markdown(),
|
||||||
|
indoc! {"
|
||||||
|
## User
|
||||||
|
|
||||||
|
Hi
|
||||||
|
|
||||||
|
## Assistant
|
||||||
|
|
||||||
|
Ahoy!
|
||||||
|
"}
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
cx.update(settings::init);
|
cx.update(settings::init);
|
||||||
|
@ -774,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
let result = cx
|
let result = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
connection.prompt(
|
connection.prompt(
|
||||||
|
Some(acp_thread::UserMessageId::new()),
|
||||||
acp::PromptRequest {
|
acp::PromptRequest {
|
||||||
session_id: session_id.clone(),
|
session_id: session_id.clone(),
|
||||||
prompt: vec!["ghi".into()],
|
prompt: vec!["ghi".into()],
|
||||||
|
@ -796,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
|
let mut events = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Think"], cx)
|
||||||
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
// Simulate streaming partial input.
|
// Simulate streaming partial input.
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||||
use acp_thread::MentionUri;
|
use acp_thread::{MentionUri, UserMessageId};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{AgentProfileId, AgentSettings};
|
use agent_settings::{AgentProfileId, AgentSettings};
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use assistant_tool::adapt_schema_to_format;
|
use assistant_tool::adapt_schema_to_format;
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||||
use collections::HashMap;
|
use collections::IndexMap;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{
|
use futures::{
|
||||||
channel::{mpsc, oneshot},
|
channel::{mpsc, oneshot},
|
||||||
|
@ -19,7 +19,6 @@ use language_model::{
|
||||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||||
};
|
};
|
||||||
use log;
|
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
use schemars::{JsonSchema, Schema};
|
use schemars::{JsonSchema, Schema};
|
||||||
|
@ -30,49 +29,199 @@ use std::fmt::Write;
|
||||||
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
||||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct AgentMessage {
|
pub enum Message {
|
||||||
pub role: Role,
|
User(UserMessage),
|
||||||
pub content: Vec<MessageContent>,
|
Agent(AgentMessage),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
pub fn as_agent_message(&self) -> Option<&AgentMessage> {
|
||||||
|
match self {
|
||||||
|
Message::Agent(agent_message) => Some(agent_message),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_markdown(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Message::User(message) => message.to_markdown(),
|
||||||
|
Message::Agent(message) => message.to_markdown(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum MessageContent {
|
pub struct UserMessage {
|
||||||
|
pub id: UserMessageId,
|
||||||
|
pub content: Vec<UserMessageContent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum UserMessageContent {
|
||||||
Text(String),
|
Text(String),
|
||||||
Thinking {
|
Mention { uri: MentionUri, content: String },
|
||||||
text: String,
|
|
||||||
signature: Option<String>,
|
|
||||||
},
|
|
||||||
Mention {
|
|
||||||
uri: MentionUri,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
RedactedThinking(String),
|
|
||||||
Image(LanguageModelImage),
|
Image(LanguageModelImage),
|
||||||
ToolUse(LanguageModelToolUse),
|
}
|
||||||
ToolResult(LanguageModelToolResult),
|
|
||||||
|
impl UserMessage {
|
||||||
|
pub fn to_markdown(&self) -> String {
|
||||||
|
let mut markdown = String::from("## User\n\n");
|
||||||
|
|
||||||
|
for content in &self.content {
|
||||||
|
match content {
|
||||||
|
UserMessageContent::Text(text) => {
|
||||||
|
markdown.push_str(text);
|
||||||
|
markdown.push('\n');
|
||||||
|
}
|
||||||
|
UserMessageContent::Image(_) => {
|
||||||
|
markdown.push_str("<image />\n");
|
||||||
|
}
|
||||||
|
UserMessageContent::Mention { uri, content } => {
|
||||||
|
if !content.is_empty() {
|
||||||
|
markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
|
||||||
|
} else {
|
||||||
|
markdown.push_str(&format!("{}\n", uri.to_link()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
markdown
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_request(&self) -> LanguageModelRequestMessage {
|
||||||
|
let mut message = LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: Vec::with_capacity(self.content.len()),
|
||||||
|
cache: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const OPEN_CONTEXT: &str = "<context>\n\
|
||||||
|
The following items were attached by the user. \
|
||||||
|
They are up-to-date and don't need to be re-read.\n\n";
|
||||||
|
|
||||||
|
const OPEN_FILES_TAG: &str = "<files>";
|
||||||
|
const OPEN_SYMBOLS_TAG: &str = "<symbols>";
|
||||||
|
const OPEN_THREADS_TAG: &str = "<threads>";
|
||||||
|
const OPEN_RULES_TAG: &str =
|
||||||
|
"<rules>\nThe user has specified the following rules that should be applied:\n";
|
||||||
|
|
||||||
|
let mut file_context = OPEN_FILES_TAG.to_string();
|
||||||
|
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
|
||||||
|
let mut thread_context = OPEN_THREADS_TAG.to_string();
|
||||||
|
let mut rules_context = OPEN_RULES_TAG.to_string();
|
||||||
|
|
||||||
|
for chunk in &self.content {
|
||||||
|
let chunk = match chunk {
|
||||||
|
UserMessageContent::Text(text) => {
|
||||||
|
language_model::MessageContent::Text(text.clone())
|
||||||
|
}
|
||||||
|
UserMessageContent::Image(value) => {
|
||||||
|
language_model::MessageContent::Image(value.clone())
|
||||||
|
}
|
||||||
|
UserMessageContent::Mention { uri, content } => {
|
||||||
|
match uri {
|
||||||
|
MentionUri::File(path) | MentionUri::Symbol(path, _) => {
|
||||||
|
write!(
|
||||||
|
&mut symbol_context,
|
||||||
|
"\n{}",
|
||||||
|
MarkdownCodeBlock {
|
||||||
|
tag: &codeblock_tag(&path),
|
||||||
|
text: &content.to_string(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
MentionUri::Thread(_session_id) => {
|
||||||
|
write!(&mut thread_context, "\n{}\n", content).ok();
|
||||||
|
}
|
||||||
|
MentionUri::Rule(_user_prompt_id) => {
|
||||||
|
write!(
|
||||||
|
&mut rules_context,
|
||||||
|
"\n{}",
|
||||||
|
MarkdownCodeBlock {
|
||||||
|
tag: "",
|
||||||
|
text: &content
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
language_model::MessageContent::Text(uri.to_link())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
message.content.push(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
let len_before_context = message.content.len();
|
||||||
|
|
||||||
|
if file_context.len() > OPEN_FILES_TAG.len() {
|
||||||
|
file_context.push_str("</files>\n");
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.push(language_model::MessageContent::Text(file_context));
|
||||||
|
}
|
||||||
|
|
||||||
|
if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
|
||||||
|
symbol_context.push_str("</symbols>\n");
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.push(language_model::MessageContent::Text(symbol_context));
|
||||||
|
}
|
||||||
|
|
||||||
|
if thread_context.len() > OPEN_THREADS_TAG.len() {
|
||||||
|
thread_context.push_str("</threads>\n");
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.push(language_model::MessageContent::Text(thread_context));
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules_context.len() > OPEN_RULES_TAG.len() {
|
||||||
|
rules_context.push_str("</user_rules>\n");
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.push(language_model::MessageContent::Text(rules_context));
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.content.len() > len_before_context {
|
||||||
|
message.content.insert(
|
||||||
|
len_before_context,
|
||||||
|
language_model::MessageContent::Text(OPEN_CONTEXT.into()),
|
||||||
|
);
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.push(language_model::MessageContent::Text("</context>".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
message
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentMessage {
|
impl AgentMessage {
|
||||||
pub fn to_markdown(&self) -> String {
|
pub fn to_markdown(&self) -> String {
|
||||||
let mut markdown = format!("## {}\n", self.role);
|
let mut markdown = String::from("## Assistant\n\n");
|
||||||
|
|
||||||
for content in &self.content {
|
for content in &self.content {
|
||||||
match content {
|
match content {
|
||||||
MessageContent::Text(text) => {
|
AgentMessageContent::Text(text) => {
|
||||||
markdown.push_str(text);
|
markdown.push_str(text);
|
||||||
markdown.push('\n');
|
markdown.push('\n');
|
||||||
}
|
}
|
||||||
MessageContent::Thinking { text, .. } => {
|
AgentMessageContent::Thinking { text, .. } => {
|
||||||
markdown.push_str("<think>");
|
markdown.push_str("<think>");
|
||||||
markdown.push_str(text);
|
markdown.push_str(text);
|
||||||
markdown.push_str("</think>\n");
|
markdown.push_str("</think>\n");
|
||||||
}
|
}
|
||||||
MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
|
AgentMessageContent::RedactedThinking(_) => {
|
||||||
MessageContent::Image(_) => {
|
markdown.push_str("<redacted_thinking />\n")
|
||||||
|
}
|
||||||
|
AgentMessageContent::Image(_) => {
|
||||||
markdown.push_str("<image />\n");
|
markdown.push_str("<image />\n");
|
||||||
}
|
}
|
||||||
MessageContent::ToolUse(tool_use) => {
|
AgentMessageContent::ToolUse(tool_use) => {
|
||||||
markdown.push_str(&format!(
|
markdown.push_str(&format!(
|
||||||
"**Tool Use**: {} (ID: {})\n",
|
"**Tool Use**: {} (ID: {})\n",
|
||||||
tool_use.name, tool_use.id
|
tool_use.name, tool_use.id
|
||||||
|
@ -85,41 +234,106 @@ impl AgentMessage {
|
||||||
}
|
}
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
MessageContent::ToolResult(tool_result) => {
|
}
|
||||||
markdown.push_str(&format!(
|
}
|
||||||
"**Tool Result**: {} (ID: {})\n\n",
|
|
||||||
tool_result.tool_name, tool_result.tool_use_id
|
|
||||||
));
|
|
||||||
if tool_result.is_error {
|
|
||||||
markdown.push_str("**ERROR:**\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
match &tool_result.content {
|
for tool_result in self.tool_results.values() {
|
||||||
LanguageModelToolResultContent::Text(text) => {
|
markdown.push_str(&format!(
|
||||||
writeln!(markdown, "{text}\n").ok();
|
"**Tool Result**: {} (ID: {})\n\n",
|
||||||
}
|
tool_result.tool_name, tool_result.tool_use_id
|
||||||
LanguageModelToolResultContent::Image(_) => {
|
));
|
||||||
writeln!(markdown, "<image />\n").ok();
|
if tool_result.is_error {
|
||||||
}
|
markdown.push_str("**ERROR:**\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(output) = tool_result.output.as_ref() {
|
match &tool_result.content {
|
||||||
writeln!(
|
LanguageModelToolResultContent::Text(text) => {
|
||||||
markdown,
|
writeln!(markdown, "{text}\n").ok();
|
||||||
"**Debug Output**:\n\n```json\n{}\n```\n",
|
|
||||||
serde_json::to_string_pretty(output).unwrap()
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
MessageContent::Mention { uri, .. } => {
|
LanguageModelToolResultContent::Image(_) => {
|
||||||
write!(markdown, "{}", uri.to_link()).ok();
|
writeln!(markdown, "<image />\n").ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(output) = tool_result.output.as_ref() {
|
||||||
|
writeln!(
|
||||||
|
markdown,
|
||||||
|
"**Debug Output**:\n\n```json\n{}\n```\n",
|
||||||
|
serde_json::to_string_pretty(output).unwrap()
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
markdown
|
markdown
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||||
|
let mut content = Vec::with_capacity(self.content.len());
|
||||||
|
for chunk in &self.content {
|
||||||
|
let chunk = match chunk {
|
||||||
|
AgentMessageContent::Text(text) => {
|
||||||
|
language_model::MessageContent::Text(text.clone())
|
||||||
|
}
|
||||||
|
AgentMessageContent::Thinking { text, signature } => {
|
||||||
|
language_model::MessageContent::Thinking {
|
||||||
|
text: text.clone(),
|
||||||
|
signature: signature.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AgentMessageContent::RedactedThinking(value) => {
|
||||||
|
language_model::MessageContent::RedactedThinking(value.clone())
|
||||||
|
}
|
||||||
|
AgentMessageContent::ToolUse(value) => {
|
||||||
|
language_model::MessageContent::ToolUse(value.clone())
|
||||||
|
}
|
||||||
|
AgentMessageContent::Image(value) => {
|
||||||
|
language_model::MessageContent::Image(value.clone())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
content.push(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut messages = vec![LanguageModelRequestMessage {
|
||||||
|
role: Role::Assistant,
|
||||||
|
content,
|
||||||
|
cache: false,
|
||||||
|
}];
|
||||||
|
|
||||||
|
if !self.tool_results.is_empty() {
|
||||||
|
let mut tool_results = Vec::with_capacity(self.tool_results.len());
|
||||||
|
for tool_result in self.tool_results.values() {
|
||||||
|
tool_results.push(language_model::MessageContent::ToolResult(
|
||||||
|
tool_result.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
messages.push(LanguageModelRequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: tool_results,
|
||||||
|
cache: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct AgentMessage {
|
||||||
|
pub content: Vec<AgentMessageContent>,
|
||||||
|
pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum AgentMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Thinking {
|
||||||
|
text: String,
|
||||||
|
signature: Option<String>,
|
||||||
|
},
|
||||||
|
RedactedThinking(String),
|
||||||
|
Image(LanguageModelImage),
|
||||||
|
ToolUse(LanguageModelToolUse),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -140,13 +354,13 @@ pub struct ToolCallAuthorization {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Thread {
|
pub struct Thread {
|
||||||
messages: Vec<AgentMessage>,
|
messages: Vec<Message>,
|
||||||
completion_mode: CompletionMode,
|
completion_mode: CompletionMode,
|
||||||
/// Holds the task that handles agent interaction until the end of the turn.
|
/// Holds the task that handles agent interaction until the end of the turn.
|
||||||
/// Survives across multiple requests as the model performs tool calls and
|
/// Survives across multiple requests as the model performs tool calls and
|
||||||
/// we run tools, report their results.
|
/// we run tools, report their results.
|
||||||
running_turn: Option<Task<()>>,
|
running_turn: Option<Task<()>>,
|
||||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
pending_agent_message: Option<AgentMessage>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
context_server_registry: Entity<ContextServerRegistry>,
|
context_server_registry: Entity<ContextServerRegistry>,
|
||||||
profile_id: AgentProfileId,
|
profile_id: AgentProfileId,
|
||||||
|
@ -172,7 +386,7 @@ impl Thread {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
completion_mode: CompletionMode::Normal,
|
completion_mode: CompletionMode::Normal,
|
||||||
running_turn: None,
|
running_turn: None,
|
||||||
pending_tool_uses: HashMap::default(),
|
pending_agent_message: None,
|
||||||
tools: BTreeMap::default(),
|
tools: BTreeMap::default(),
|
||||||
context_server_registry,
|
context_server_registry,
|
||||||
profile_id,
|
profile_id,
|
||||||
|
@ -196,8 +410,13 @@ impl Thread {
|
||||||
self.completion_mode = mode;
|
self.completion_mode = mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn messages(&self) -> &[AgentMessage] {
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
&self.messages
|
pub fn last_message(&self) -> Option<Message> {
|
||||||
|
if let Some(message) = self.pending_agent_message.clone() {
|
||||||
|
Some(Message::Agent(message))
|
||||||
|
} else {
|
||||||
|
self.messages.last().cloned()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||||
|
@ -213,35 +432,36 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cancel(&mut self) {
|
pub fn cancel(&mut self) {
|
||||||
|
// TODO: do we need to emit a stop::cancel for ACP?
|
||||||
self.running_turn.take();
|
self.running_turn.take();
|
||||||
|
self.flush_pending_agent_message();
|
||||||
|
}
|
||||||
|
|
||||||
let tool_results = self
|
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
|
||||||
.pending_tool_uses
|
self.cancel();
|
||||||
.drain()
|
let Some(position) = self.messages.iter().position(
|
||||||
.map(|(tool_use_id, tool_use)| {
|
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
|
||||||
MessageContent::ToolResult(LanguageModelToolResult {
|
) else {
|
||||||
tool_use_id,
|
return Err(anyhow!("Message not found"));
|
||||||
tool_name: tool_use.name.clone(),
|
};
|
||||||
is_error: true,
|
self.messages.truncate(position);
|
||||||
content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
|
Ok(())
|
||||||
output: None,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
self.last_user_message().content.extend(tool_results);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||||
pub fn send(
|
pub fn send<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
content: impl Into<UserMessage>,
|
message_id: UserMessageId,
|
||||||
|
content: impl IntoIterator<Item = T>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
|
||||||
let content = content.into().0;
|
where
|
||||||
|
T: Into<UserMessageContent>,
|
||||||
|
{
|
||||||
let model = self.selected_model.clone();
|
let model = self.selected_model.clone();
|
||||||
|
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||||
log::info!("Thread::send called with model: {:?}", model.name());
|
log::info!("Thread::send called with model: {:?}", model.name());
|
||||||
log::debug!("Thread::send content: {:?}", content);
|
log::debug!("Thread::send content: {:?}", content);
|
||||||
|
|
||||||
|
@ -251,10 +471,10 @@ impl Thread {
|
||||||
let event_stream = AgentResponseEventStream(events_tx);
|
let event_stream = AgentResponseEventStream(events_tx);
|
||||||
|
|
||||||
let user_message_ix = self.messages.len();
|
let user_message_ix = self.messages.len();
|
||||||
self.messages.push(AgentMessage {
|
self.messages.push(Message::User(UserMessage {
|
||||||
role: Role::User,
|
id: message_id,
|
||||||
content,
|
content,
|
||||||
});
|
}));
|
||||||
log::info!("Total messages in thread: {}", self.messages.len());
|
log::info!("Total messages in thread: {}", self.messages.len());
|
||||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
|
@ -270,15 +490,11 @@ impl Thread {
|
||||||
thread.build_completion_request(completion_intent, cx)
|
thread.build_completion_request(completion_intent, cx)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// println!(
|
|
||||||
// "request: {}",
|
|
||||||
// serde_json::to_string_pretty(&request).unwrap()
|
|
||||||
// );
|
|
||||||
|
|
||||||
// Stream events, appending to messages and collecting up tool uses.
|
// Stream events, appending to messages and collecting up tool uses.
|
||||||
log::info!("Calling model.stream_completion");
|
log::info!("Calling model.stream_completion");
|
||||||
let mut events = model.stream_completion(request, cx).await?;
|
let mut events = model.stream_completion(request, cx).await?;
|
||||||
log::debug!("Stream completion started successfully");
|
log::debug!("Stream completion started successfully");
|
||||||
|
|
||||||
let mut tool_uses = FuturesUnordered::new();
|
let mut tool_uses = FuturesUnordered::new();
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
match event {
|
match event {
|
||||||
|
@ -286,6 +502,7 @@ impl Thread {
|
||||||
event_stream.send_stop(reason);
|
event_stream.send_stop(reason);
|
||||||
if reason == StopReason::Refusal {
|
if reason == StopReason::Refusal {
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
|
thread.pending_agent_message = None;
|
||||||
thread.messages.truncate(user_message_ix);
|
thread.messages.truncate(user_message_ix);
|
||||||
})?;
|
})?;
|
||||||
break 'outer;
|
break 'outer;
|
||||||
|
@ -338,15 +555,16 @@ impl Thread {
|
||||||
);
|
);
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, _cx| {
|
.update(cx, |thread, _cx| {
|
||||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
|
||||||
thread
|
thread
|
||||||
.last_user_message()
|
.pending_agent_message()
|
||||||
.content
|
.tool_results
|
||||||
.push(MessageContent::ToolResult(tool_result));
|
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||||
})
|
})
|
||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
|
||||||
|
|
||||||
completion_intent = CompletionIntent::ToolResults;
|
completion_intent = CompletionIntent::ToolResults;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,6 +572,10 @@ impl Thread {
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
thread
|
||||||
|
.update(cx, |thread, _cx| thread.flush_pending_agent_message())
|
||||||
|
.ok();
|
||||||
|
|
||||||
if let Err(error) = turn_result {
|
if let Err(error) = turn_result {
|
||||||
log::error!("Turn execution failed: {:?}", error);
|
log::error!("Turn execution failed: {:?}", error);
|
||||||
event_stream.send_error(error);
|
event_stream.send_error(error);
|
||||||
|
@ -364,7 +586,7 @@ impl Thread {
|
||||||
events_rx
|
events_rx
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_system_message(&self) -> AgentMessage {
|
pub fn build_system_message(&self) -> LanguageModelRequestMessage {
|
||||||
log::debug!("Building system message");
|
log::debug!("Building system message");
|
||||||
let prompt = SystemPromptTemplate {
|
let prompt = SystemPromptTemplate {
|
||||||
project: &self.project_context.borrow(),
|
project: &self.project_context.borrow(),
|
||||||
|
@ -374,9 +596,10 @@ impl Thread {
|
||||||
.context("failed to build system prompt")
|
.context("failed to build system prompt")
|
||||||
.expect("Invalid template");
|
.expect("Invalid template");
|
||||||
log::debug!("System message built");
|
log::debug!("System message built");
|
||||||
AgentMessage {
|
LanguageModelRequestMessage {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
content: vec![prompt.as_str().into()],
|
content: vec![prompt.into()],
|
||||||
|
cache: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -394,10 +617,7 @@ impl Thread {
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
StartMessage { .. } => {
|
StartMessage { .. } => {
|
||||||
self.messages.push(AgentMessage {
|
self.messages.push(Message::Agent(AgentMessage::default()));
|
||||||
role: Role::Assistant,
|
|
||||||
content: Vec::new(),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
||||||
Thinking { text, signature } => {
|
Thinking { text, signature } => {
|
||||||
|
@ -435,11 +655,13 @@ impl Thread {
|
||||||
) {
|
) {
|
||||||
events_stream.send_text(&new_text);
|
events_stream.send_text(&new_text);
|
||||||
|
|
||||||
let last_message = self.last_assistant_message();
|
let last_message = self.pending_agent_message();
|
||||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
||||||
text.push_str(&new_text);
|
text.push_str(&new_text);
|
||||||
} else {
|
} else {
|
||||||
last_message.content.push(MessageContent::Text(new_text));
|
last_message
|
||||||
|
.content
|
||||||
|
.push(AgentMessageContent::Text(new_text));
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -454,13 +676,14 @@ impl Thread {
|
||||||
) {
|
) {
|
||||||
event_stream.send_thinking(&new_text);
|
event_stream.send_thinking(&new_text);
|
||||||
|
|
||||||
let last_message = self.last_assistant_message();
|
let last_message = self.pending_agent_message();
|
||||||
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
|
if let Some(AgentMessageContent::Thinking { text, signature }) =
|
||||||
|
last_message.content.last_mut()
|
||||||
{
|
{
|
||||||
text.push_str(&new_text);
|
text.push_str(&new_text);
|
||||||
*signature = new_signature.or(signature.take());
|
*signature = new_signature.or(signature.take());
|
||||||
} else {
|
} else {
|
||||||
last_message.content.push(MessageContent::Thinking {
|
last_message.content.push(AgentMessageContent::Thinking {
|
||||||
text: new_text,
|
text: new_text,
|
||||||
signature: new_signature,
|
signature: new_signature,
|
||||||
});
|
});
|
||||||
|
@ -470,10 +693,10 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
|
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
|
||||||
let last_message = self.last_assistant_message();
|
let last_message = self.pending_agent_message();
|
||||||
last_message
|
last_message
|
||||||
.content
|
.content
|
||||||
.push(MessageContent::RedactedThinking(data));
|
.push(AgentMessageContent::RedactedThinking(data));
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -486,14 +709,17 @@ impl Thread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
||||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||||
|
let mut title = SharedString::from(&tool_use.name);
|
||||||
self.pending_tool_uses
|
let mut kind = acp::ToolKind::Other;
|
||||||
.insert(tool_use.id.clone(), tool_use.clone());
|
if let Some(tool) = tool.as_ref() {
|
||||||
let last_message = self.last_assistant_message();
|
title = tool.initial_title(tool_use.input.clone());
|
||||||
|
kind = tool.kind();
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the last message ends in the current tool use
|
// Ensure the last message ends in the current tool use
|
||||||
|
let last_message = self.pending_agent_message();
|
||||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
if let AgentMessageContent::ToolUse(last_tool_use) = content {
|
||||||
if last_tool_use.id == tool_use.id {
|
if last_tool_use.id == tool_use.id {
|
||||||
*last_tool_use = tool_use.clone();
|
*last_tool_use = tool_use.clone();
|
||||||
false
|
false
|
||||||
|
@ -505,18 +731,11 @@ impl Thread {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut title = SharedString::from(&tool_use.name);
|
|
||||||
let mut kind = acp::ToolKind::Other;
|
|
||||||
if let Some(tool) = tool.as_ref() {
|
|
||||||
title = tool.initial_title(tool_use.input.clone());
|
|
||||||
kind = tool.kind();
|
|
||||||
}
|
|
||||||
|
|
||||||
if push_new_tool_use {
|
if push_new_tool_use {
|
||||||
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
||||||
last_message
|
last_message
|
||||||
.content
|
.content
|
||||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
.push(AgentMessageContent::ToolUse(tool_use.clone()));
|
||||||
} else {
|
} else {
|
||||||
event_stream.update_tool_call_fields(
|
event_stream.update_tool_call_fields(
|
||||||
&tool_use.id,
|
&tool_use.id,
|
||||||
|
@ -601,30 +820,37 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
fn pending_agent_message(&mut self) -> &mut AgentMessage {
|
||||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
self.pending_agent_message.get_or_insert_default()
|
||||||
if self
|
|
||||||
.messages
|
|
||||||
.last()
|
|
||||||
.map_or(true, |m| m.role != Role::Assistant)
|
|
||||||
{
|
|
||||||
self.messages.push(AgentMessage {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: Vec::new(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
self.messages.last_mut().unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Guarantees the last message is from the user and returns a mutable reference.
|
fn flush_pending_agent_message(&mut self) {
|
||||||
fn last_user_message(&mut self) -> &mut AgentMessage {
|
let Some(mut message) = self.pending_agent_message.take() else {
|
||||||
if self.messages.last().map_or(true, |m| m.role != Role::User) {
|
return;
|
||||||
self.messages.push(AgentMessage {
|
};
|
||||||
role: Role::User,
|
|
||||||
content: Vec::new(),
|
for content in &message.content {
|
||||||
});
|
let AgentMessageContent::ToolUse(tool_use) = content else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if !message.tool_results.contains_key(&tool_use.id) {
|
||||||
|
message.tool_results.insert(
|
||||||
|
tool_use.id.clone(),
|
||||||
|
LanguageModelToolResult {
|
||||||
|
tool_use_id: tool_use.id.clone(),
|
||||||
|
tool_name: tool_use.name.clone(),
|
||||||
|
is_error: true,
|
||||||
|
content: LanguageModelToolResultContent::Text(
|
||||||
|
"Tool canceled by user".into(),
|
||||||
|
),
|
||||||
|
output: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
self.messages.last_mut().unwrap()
|
|
||||||
|
self.messages.push(Message::Agent(message));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn build_completion_request(
|
pub(crate) fn build_completion_request(
|
||||||
|
@ -712,49 +938,39 @@ impl Thread {
|
||||||
"Building request messages from {} thread messages",
|
"Building request messages from {} thread messages",
|
||||||
self.messages.len()
|
self.messages.len()
|
||||||
);
|
);
|
||||||
|
let mut messages = vec![self.build_system_message()];
|
||||||
|
for message in &self.messages {
|
||||||
|
match message {
|
||||||
|
Message::User(message) => messages.push(message.to_request()),
|
||||||
|
Message::Agent(message) => messages.extend(message.to_request()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||||
|
messages.extend(message.to_request());
|
||||||
|
}
|
||||||
|
|
||||||
let messages = Some(self.build_system_message())
|
|
||||||
.iter()
|
|
||||||
.chain(self.messages.iter())
|
|
||||||
.map(|message| {
|
|
||||||
log::trace!(
|
|
||||||
" - {} message with {} content items",
|
|
||||||
match message.role {
|
|
||||||
Role::System => "System",
|
|
||||||
Role::User => "User",
|
|
||||||
Role::Assistant => "Assistant",
|
|
||||||
},
|
|
||||||
message.content.len()
|
|
||||||
);
|
|
||||||
message.to_request()
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
messages
|
messages
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_markdown(&self) -> String {
|
pub fn to_markdown(&self) -> String {
|
||||||
let mut markdown = String::new();
|
let mut markdown = String::new();
|
||||||
for message in &self.messages {
|
for (ix, message) in self.messages.iter().enumerate() {
|
||||||
|
if ix > 0 {
|
||||||
|
markdown.push('\n');
|
||||||
|
}
|
||||||
markdown.push_str(&message.to_markdown());
|
markdown.push_str(&message.to_markdown());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||||
|
markdown.push('\n');
|
||||||
|
markdown.push_str(&message.to_markdown());
|
||||||
|
}
|
||||||
|
|
||||||
markdown
|
markdown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UserMessage(Vec<MessageContent>);
|
|
||||||
|
|
||||||
impl From<Vec<MessageContent>> for UserMessage {
|
|
||||||
fn from(content: Vec<MessageContent>) -> Self {
|
|
||||||
UserMessage(content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Into<MessageContent>> From<T> for UserMessage {
|
|
||||||
fn from(content: T) -> Self {
|
|
||||||
UserMessage(vec![content.into()])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait AgentTool
|
pub trait AgentTool
|
||||||
where
|
where
|
||||||
Self: 'static + Sized,
|
Self: 'static + Sized,
|
||||||
|
@ -1151,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentMessage {
|
|
||||||
fn to_request(&self) -> language_model::LanguageModelRequestMessage {
|
|
||||||
let mut message = LanguageModelRequestMessage {
|
|
||||||
role: self.role,
|
|
||||||
content: Vec::with_capacity(self.content.len()),
|
|
||||||
cache: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
const OPEN_CONTEXT: &str = "<context>\n\
|
|
||||||
The following items were attached by the user. \
|
|
||||||
They are up-to-date and don't need to be re-read.\n\n";
|
|
||||||
|
|
||||||
const OPEN_FILES_TAG: &str = "<files>";
|
|
||||||
const OPEN_SYMBOLS_TAG: &str = "<symbols>";
|
|
||||||
const OPEN_THREADS_TAG: &str = "<threads>";
|
|
||||||
const OPEN_RULES_TAG: &str =
|
|
||||||
"<rules>\nThe user has specified the following rules that should be applied:\n";
|
|
||||||
|
|
||||||
let mut file_context = OPEN_FILES_TAG.to_string();
|
|
||||||
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
|
|
||||||
let mut thread_context = OPEN_THREADS_TAG.to_string();
|
|
||||||
let mut rules_context = OPEN_RULES_TAG.to_string();
|
|
||||||
|
|
||||||
for chunk in &self.content {
|
|
||||||
let chunk = match chunk {
|
|
||||||
MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
|
|
||||||
MessageContent::Thinking { text, signature } => {
|
|
||||||
language_model::MessageContent::Thinking {
|
|
||||||
text: text.clone(),
|
|
||||||
signature: signature.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MessageContent::RedactedThinking(value) => {
|
|
||||||
language_model::MessageContent::RedactedThinking(value.clone())
|
|
||||||
}
|
|
||||||
MessageContent::ToolUse(value) => {
|
|
||||||
language_model::MessageContent::ToolUse(value.clone())
|
|
||||||
}
|
|
||||||
MessageContent::ToolResult(value) => {
|
|
||||||
language_model::MessageContent::ToolResult(value.clone())
|
|
||||||
}
|
|
||||||
MessageContent::Image(value) => {
|
|
||||||
language_model::MessageContent::Image(value.clone())
|
|
||||||
}
|
|
||||||
MessageContent::Mention { uri, content } => {
|
|
||||||
match uri {
|
|
||||||
MentionUri::File(path) | MentionUri::Symbol(path, _) => {
|
|
||||||
write!(
|
|
||||||
&mut symbol_context,
|
|
||||||
"\n{}",
|
|
||||||
MarkdownCodeBlock {
|
|
||||||
tag: &codeblock_tag(&path),
|
|
||||||
text: &content.to_string(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
MentionUri::Thread(_session_id) => {
|
|
||||||
write!(&mut thread_context, "\n{}\n", content).ok();
|
|
||||||
}
|
|
||||||
MentionUri::Rule(_user_prompt_id) => {
|
|
||||||
write!(
|
|
||||||
&mut rules_context,
|
|
||||||
"\n{}",
|
|
||||||
MarkdownCodeBlock {
|
|
||||||
tag: "",
|
|
||||||
text: &content
|
|
||||||
}
|
|
||||||
)
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
language_model::MessageContent::Text(uri.to_link())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
message.content.push(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
let len_before_context = message.content.len();
|
|
||||||
|
|
||||||
if file_context.len() > OPEN_FILES_TAG.len() {
|
|
||||||
file_context.push_str("</files>\n");
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.push(language_model::MessageContent::Text(file_context));
|
|
||||||
}
|
|
||||||
|
|
||||||
if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
|
|
||||||
symbol_context.push_str("</symbols>\n");
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.push(language_model::MessageContent::Text(symbol_context));
|
|
||||||
}
|
|
||||||
|
|
||||||
if thread_context.len() > OPEN_THREADS_TAG.len() {
|
|
||||||
thread_context.push_str("</threads>\n");
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.push(language_model::MessageContent::Text(thread_context));
|
|
||||||
}
|
|
||||||
|
|
||||||
if rules_context.len() > OPEN_RULES_TAG.len() {
|
|
||||||
rules_context.push_str("</user_rules>\n");
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.push(language_model::MessageContent::Text(rules_context));
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.content.len() > len_before_context {
|
|
||||||
message.content.insert(
|
|
||||||
len_before_context,
|
|
||||||
language_model::MessageContent::Text(OPEN_CONTEXT.into()),
|
|
||||||
);
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.push(language_model::MessageContent::Text("</context>".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn codeblock_tag(full_path: &Path) -> String {
|
fn codeblock_tag(full_path: &Path) -> String {
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
|
|
||||||
|
@ -1287,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<acp::ContentBlock> for MessageContent {
|
impl From<&str> for UserMessageContent {
|
||||||
|
fn from(text: &str) -> Self {
|
||||||
|
Self::Text(text.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<acp::ContentBlock> for UserMessageContent {
|
||||||
fn from(value: acp::ContentBlock) -> Self {
|
fn from(value: acp::ContentBlock) -> Self {
|
||||||
match value {
|
match value {
|
||||||
acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
|
acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
|
||||||
acp::ContentBlock::Image(image_content) => {
|
acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
|
||||||
MessageContent::Image(convert_image(image_content))
|
|
||||||
}
|
|
||||||
acp::ContentBlock::Audio(_) => {
|
acp::ContentBlock::Audio(_) => {
|
||||||
// TODO
|
// TODO
|
||||||
MessageContent::Text("[audio]".to_string())
|
Self::Text("[audio]".to_string())
|
||||||
}
|
}
|
||||||
acp::ContentBlock::ResourceLink(resource_link) => {
|
acp::ContentBlock::ResourceLink(resource_link) => {
|
||||||
match MentionUri::parse(&resource_link.uri) {
|
match MentionUri::parse(&resource_link.uri) {
|
||||||
|
@ -1306,10 +1402,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
||||||
},
|
},
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!("Failed to parse mention link: {}", err);
|
log::error!("Failed to parse mention link: {}", err);
|
||||||
MessageContent::Text(format!(
|
Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
|
||||||
"[{}]({})",
|
|
||||||
resource_link.name, resource_link.uri
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1322,7 +1415,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
||||||
},
|
},
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
log::error!("Failed to parse mention link: {}", err);
|
log::error!("Failed to parse mention link: {}", err);
|
||||||
MessageContent::Text(
|
Self::Text(
|
||||||
MarkdownCodeBlock {
|
MarkdownCodeBlock {
|
||||||
tag: &resource.uri,
|
tag: &resource.uri,
|
||||||
text: &resource.text,
|
text: &resource.text,
|
||||||
|
@ -1334,7 +1427,7 @@ impl From<acp::ContentBlock> for MessageContent {
|
||||||
}
|
}
|
||||||
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
|
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
|
||||||
// TODO
|
// TODO
|
||||||
MessageContent::Text("[blob]".to_string())
|
Self::Text("[blob]".to_string())
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -1348,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
|
||||||
size: gpui::Size::new(0.into(), 0.into()),
|
size: gpui::Size::new(0.into(), 0.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&str> for MessageContent {
|
|
||||||
fn from(text: &str) -> Self {
|
|
||||||
MessageContent::Text(text.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<acp_thread::UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
|
|
@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<acp_thread::UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
|
|
@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<acp_thread::UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
@ -423,7 +424,7 @@ impl ClaudeAgentSession {
|
||||||
if !turn_state.borrow().is_cancelled() {
|
if !turn_state.borrow().is_cancelled() {
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(text.into(), cx)
|
thread.push_user_content_block(None, text.into(), cx)
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
}
|
}
|
||||||
|
|
|
@ -679,17 +679,19 @@ impl AcpThreadView {
|
||||||
window: &mut Window,
|
window: &mut Window,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
let count = self.list_state.item_count();
|
|
||||||
match event {
|
match event {
|
||||||
AcpThreadEvent::NewEntry => {
|
AcpThreadEvent::NewEntry => {
|
||||||
let index = thread.read(cx).entries().len() - 1;
|
let index = thread.read(cx).entries().len() - 1;
|
||||||
self.sync_thread_entry_view(index, window, cx);
|
self.sync_thread_entry_view(index, window, cx);
|
||||||
self.list_state.splice(count..count, 1);
|
self.list_state.splice(index..index, 1);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::EntryUpdated(index) => {
|
AcpThreadEvent::EntryUpdated(index) => {
|
||||||
let index = *index;
|
self.sync_thread_entry_view(*index, window, cx);
|
||||||
self.sync_thread_entry_view(index, window, cx);
|
self.list_state.splice(*index..index + 1, 1);
|
||||||
self.list_state.splice(index..index + 1, 1);
|
}
|
||||||
|
AcpThreadEvent::EntriesRemoved(range) => {
|
||||||
|
// TODO: Clean up unused diff editors and terminal views
|
||||||
|
self.list_state.splice(range.clone(), 0);
|
||||||
}
|
}
|
||||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||||
|
@ -3789,6 +3791,7 @@ mod tests {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<acp_thread::UserMessageId>,
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
@ -3873,6 +3876,7 @@ mod tests {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
|
_id: Option<acp_thread::UserMessageId>,
|
||||||
_params: acp::PromptRequest,
|
_params: acp::PromptRequest,
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
|
|
@ -1521,7 +1521,8 @@ impl AgentDiff {
|
||||||
self.update_reviewing_editors(workspace, window, cx);
|
self.update_reviewing_editors(workspace, window, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AcpThreadEvent::Stopped
|
AcpThreadEvent::EntriesRemoved(_)
|
||||||
|
| AcpThreadEvent::Stopped
|
||||||
| AcpThreadEvent::ToolAuthorizationRequired
|
| AcpThreadEvent::ToolAuthorizationRequired
|
||||||
| AcpThreadEvent::Error
|
| AcpThreadEvent::Error
|
||||||
| AcpThreadEvent::ServerExited(_) => {}
|
| AcpThreadEvent::ServerExited(_) => {}
|
||||||
|
|
|
@ -51,6 +51,7 @@ ashpd.workspace = true
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
git = { workspace = true, features = ["test-support"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
test-support = ["gpui/test-support", "git/test-support"]
|
test-support = ["gpui/test-support", "git/test-support"]
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
use crate::{FakeFs, Fs};
|
use crate::{FakeFs, FakeFsEntry, Fs};
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use futures::future::{self, BoxFuture, join_all};
|
use futures::future::{self, BoxFuture, join_all};
|
||||||
use git::{
|
use git::{
|
||||||
|
Oid,
|
||||||
blame::Blame,
|
blame::Blame,
|
||||||
repository::{
|
repository::{
|
||||||
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
|
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
|
||||||
|
@ -12,6 +13,7 @@ use git::{
|
||||||
};
|
};
|
||||||
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
|
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
|
||||||
use ignore::gitignore::GitignoreBuilder;
|
use ignore::gitignore::GitignoreBuilder;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rope::Rope;
|
use rope::Rope;
|
||||||
use smol::future::FutureExt as _;
|
use smol::future::FutureExt as _;
|
||||||
use std::{path::PathBuf, sync::Arc};
|
use std::{path::PathBuf, sync::Arc};
|
||||||
|
@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FakeGitRepository {
|
pub struct FakeGitRepository {
|
||||||
pub(crate) fs: Arc<FakeFs>,
|
pub(crate) fs: Arc<FakeFs>,
|
||||||
|
pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
|
||||||
pub(crate) executor: BackgroundExecutor,
|
pub(crate) executor: BackgroundExecutor,
|
||||||
pub(crate) dot_git_path: PathBuf,
|
pub(crate) dot_git_path: PathBuf,
|
||||||
pub(crate) repository_dir_path: PathBuf,
|
pub(crate) repository_dir_path: PathBuf,
|
||||||
|
@ -469,22 +472,57 @@ impl GitRepository for FakeGitRepository {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
|
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
|
||||||
unimplemented!()
|
let executor = self.executor.clone();
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
let checkpoints = self.checkpoints.clone();
|
||||||
|
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||||
|
async move {
|
||||||
|
executor.simulate_random_delay().await;
|
||||||
|
let oid = Oid::random(&mut executor.rng());
|
||||||
|
let entry = fs.entry(&repository_dir_path)?;
|
||||||
|
checkpoints.lock().insert(oid, entry);
|
||||||
|
Ok(GitRepositoryCheckpoint { commit_sha: oid })
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn restore_checkpoint(
|
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
|
||||||
&self,
|
let executor = self.executor.clone();
|
||||||
_checkpoint: GitRepositoryCheckpoint,
|
let fs = self.fs.clone();
|
||||||
) -> BoxFuture<'_, Result<()>> {
|
let checkpoints = self.checkpoints.clone();
|
||||||
unimplemented!()
|
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||||
|
async move {
|
||||||
|
executor.simulate_random_delay().await;
|
||||||
|
let checkpoints = checkpoints.lock();
|
||||||
|
let entry = checkpoints
|
||||||
|
.get(&checkpoint.commit_sha)
|
||||||
|
.context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
|
||||||
|
fs.insert_entry(&repository_dir_path, entry.clone())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compare_checkpoints(
|
fn compare_checkpoints(
|
||||||
&self,
|
&self,
|
||||||
_left: GitRepositoryCheckpoint,
|
left: GitRepositoryCheckpoint,
|
||||||
_right: GitRepositoryCheckpoint,
|
right: GitRepositoryCheckpoint,
|
||||||
) -> BoxFuture<'_, Result<bool>> {
|
) -> BoxFuture<'_, Result<bool>> {
|
||||||
unimplemented!()
|
let executor = self.executor.clone();
|
||||||
|
let checkpoints = self.checkpoints.clone();
|
||||||
|
async move {
|
||||||
|
executor.simulate_random_delay().await;
|
||||||
|
let checkpoints = checkpoints.lock();
|
||||||
|
let left = checkpoints
|
||||||
|
.get(&left.commit_sha)
|
||||||
|
.context(format!("invalid left checkpoint: {}", left.commit_sha))?;
|
||||||
|
let right = checkpoints
|
||||||
|
.get(&right.commit_sha)
|
||||||
|
.context(format!("invalid right checkpoint: {}", right.commit_sha))?;
|
||||||
|
|
||||||
|
Ok(left == right)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn diff_checkpoints(
|
fn diff_checkpoints(
|
||||||
|
@ -499,3 +537,63 @@ impl GitRepository for FakeGitRepository {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::{FakeFs, Fs};
|
||||||
|
use gpui::BackgroundExecutor;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::path::Path;
|
||||||
|
use util::path;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_checkpoints(executor: BackgroundExecutor) {
|
||||||
|
let fs = FakeFs::new(executor);
|
||||||
|
fs.insert_tree(
|
||||||
|
path!("/"),
|
||||||
|
json!({
|
||||||
|
"bar": {
|
||||||
|
"baz": "qux"
|
||||||
|
},
|
||||||
|
"foo": {
|
||||||
|
".git": {},
|
||||||
|
"a": "lorem",
|
||||||
|
"b": "ipsum",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
|
||||||
|
.unwrap();
|
||||||
|
let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
|
||||||
|
|
||||||
|
let checkpoint_1 = repository.checkpoint().await.unwrap();
|
||||||
|
fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
|
||||||
|
fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
|
||||||
|
let checkpoint_2 = repository.checkpoint().await.unwrap();
|
||||||
|
let checkpoint_3 = repository.checkpoint().await.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
repository
|
||||||
|
.compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!repository
|
||||||
|
.compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
repository.restore_checkpoint(checkpoint_1).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
fs.files_with_contents(Path::new("")),
|
||||||
|
[
|
||||||
|
(Path::new("/bar/baz").into(), b"qux".into()),
|
||||||
|
(Path::new("/foo/a").into(), b"lorem".into()),
|
||||||
|
(Path::new("/foo/b").into(), b"ipsum".into())
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -924,7 +924,7 @@ pub struct FakeFs {
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
struct FakeFsState {
|
struct FakeFsState {
|
||||||
root: Arc<Mutex<FakeFsEntry>>,
|
root: FakeFsEntry,
|
||||||
next_inode: u64,
|
next_inode: u64,
|
||||||
next_mtime: SystemTime,
|
next_mtime: SystemTime,
|
||||||
git_event_tx: smol::channel::Sender<PathBuf>,
|
git_event_tx: smol::channel::Sender<PathBuf>,
|
||||||
|
@ -939,7 +939,7 @@ struct FakeFsState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
enum FakeFsEntry {
|
enum FakeFsEntry {
|
||||||
File {
|
File {
|
||||||
inode: u64,
|
inode: u64,
|
||||||
|
@ -953,7 +953,7 @@ enum FakeFsEntry {
|
||||||
inode: u64,
|
inode: u64,
|
||||||
mtime: MTime,
|
mtime: MTime,
|
||||||
len: u64,
|
len: u64,
|
||||||
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
|
entries: BTreeMap<String, FakeFsEntry>,
|
||||||
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
|
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
|
||||||
},
|
},
|
||||||
Symlink {
|
Symlink {
|
||||||
|
@ -961,6 +961,67 @@ enum FakeFsEntry {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
impl PartialEq for FakeFsEntry {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(
|
||||||
|
Self::File {
|
||||||
|
inode: l_inode,
|
||||||
|
mtime: l_mtime,
|
||||||
|
len: l_len,
|
||||||
|
content: l_content,
|
||||||
|
git_dir_path: l_git_dir_path,
|
||||||
|
},
|
||||||
|
Self::File {
|
||||||
|
inode: r_inode,
|
||||||
|
mtime: r_mtime,
|
||||||
|
len: r_len,
|
||||||
|
content: r_content,
|
||||||
|
git_dir_path: r_git_dir_path,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
l_inode == r_inode
|
||||||
|
&& l_mtime == r_mtime
|
||||||
|
&& l_len == r_len
|
||||||
|
&& l_content == r_content
|
||||||
|
&& l_git_dir_path == r_git_dir_path
|
||||||
|
}
|
||||||
|
(
|
||||||
|
Self::Dir {
|
||||||
|
inode: l_inode,
|
||||||
|
mtime: l_mtime,
|
||||||
|
len: l_len,
|
||||||
|
entries: l_entries,
|
||||||
|
git_repo_state: l_git_repo_state,
|
||||||
|
},
|
||||||
|
Self::Dir {
|
||||||
|
inode: r_inode,
|
||||||
|
mtime: r_mtime,
|
||||||
|
len: r_len,
|
||||||
|
entries: r_entries,
|
||||||
|
git_repo_state: r_git_repo_state,
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
|
||||||
|
(Some(l), Some(r)) => Arc::ptr_eq(l, r),
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
l_inode == r_inode
|
||||||
|
&& l_mtime == r_mtime
|
||||||
|
&& l_len == r_len
|
||||||
|
&& l_entries == r_entries
|
||||||
|
&& same_repo_state
|
||||||
|
}
|
||||||
|
(Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
|
||||||
|
l_target == r_target
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
impl FakeFsState {
|
impl FakeFsState {
|
||||||
fn get_and_increment_mtime(&mut self) -> MTime {
|
fn get_and_increment_mtime(&mut self) -> MTime {
|
||||||
|
@ -975,25 +1036,9 @@ impl FakeFsState {
|
||||||
inode
|
inode
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
|
fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
|
||||||
Ok(self
|
|
||||||
.try_read_path(target, true)
|
|
||||||
.ok_or_else(|| {
|
|
||||||
anyhow!(io::Error::new(
|
|
||||||
io::ErrorKind::NotFound,
|
|
||||||
format!("not found: {target:?}")
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_read_path(
|
|
||||||
&self,
|
|
||||||
target: &Path,
|
|
||||||
follow_symlink: bool,
|
|
||||||
) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
|
|
||||||
let mut path = target.to_path_buf();
|
|
||||||
let mut canonical_path = PathBuf::new();
|
let mut canonical_path = PathBuf::new();
|
||||||
|
let mut path = target.to_path_buf();
|
||||||
let mut entry_stack = Vec::new();
|
let mut entry_stack = Vec::new();
|
||||||
'outer: loop {
|
'outer: loop {
|
||||||
let mut path_components = path.components().peekable();
|
let mut path_components = path.components().peekable();
|
||||||
|
@ -1003,7 +1048,7 @@ impl FakeFsState {
|
||||||
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
|
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
|
||||||
Component::RootDir => {
|
Component::RootDir => {
|
||||||
entry_stack.clear();
|
entry_stack.clear();
|
||||||
entry_stack.push(self.root.clone());
|
entry_stack.push(&self.root);
|
||||||
canonical_path.clear();
|
canonical_path.clear();
|
||||||
match prefix {
|
match prefix {
|
||||||
Some(prefix_component) => {
|
Some(prefix_component) => {
|
||||||
|
@ -1020,20 +1065,18 @@ impl FakeFsState {
|
||||||
canonical_path.pop();
|
canonical_path.pop();
|
||||||
}
|
}
|
||||||
Component::Normal(name) => {
|
Component::Normal(name) => {
|
||||||
let current_entry = entry_stack.last().cloned()?;
|
let current_entry = *entry_stack.last()?;
|
||||||
let current_entry = current_entry.lock();
|
if let FakeFsEntry::Dir { entries, .. } = current_entry {
|
||||||
if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
|
let entry = entries.get(name.to_str().unwrap())?;
|
||||||
let entry = entries.get(name.to_str().unwrap()).cloned()?;
|
|
||||||
if path_components.peek().is_some() || follow_symlink {
|
if path_components.peek().is_some() || follow_symlink {
|
||||||
let entry = entry.lock();
|
if let FakeFsEntry::Symlink { target, .. } = entry {
|
||||||
if let FakeFsEntry::Symlink { target, .. } = &*entry {
|
|
||||||
let mut target = target.clone();
|
let mut target = target.clone();
|
||||||
target.extend(path_components);
|
target.extend(path_components);
|
||||||
path = target;
|
path = target;
|
||||||
continue 'outer;
|
continue 'outer;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
entry_stack.push(entry.clone());
|
entry_stack.push(entry);
|
||||||
canonical_path = canonical_path.join(name);
|
canonical_path = canonical_path.join(name);
|
||||||
} else {
|
} else {
|
||||||
return None;
|
return None;
|
||||||
|
@ -1043,19 +1086,72 @@ impl FakeFsState {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Some((entry_stack.pop()?, canonical_path))
|
|
||||||
|
if entry_stack.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(canonical_path)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
|
fn try_entry(
|
||||||
|
&mut self,
|
||||||
|
target: &Path,
|
||||||
|
follow_symlink: bool,
|
||||||
|
) -> Option<(&mut FakeFsEntry, PathBuf)> {
|
||||||
|
let canonical_path = self.canonicalize(target, follow_symlink)?;
|
||||||
|
|
||||||
|
let mut components = canonical_path.components();
|
||||||
|
let Some(Component::RootDir) = components.next() else {
|
||||||
|
panic!(
|
||||||
|
"the path {:?} was not canonicalized properly {:?}",
|
||||||
|
target, canonical_path
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut entry = &mut self.root;
|
||||||
|
for component in components {
|
||||||
|
match component {
|
||||||
|
Component::Normal(name) => {
|
||||||
|
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||||
|
entry = entries.get_mut(name.to_str().unwrap())?;
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!(
|
||||||
|
"the path {:?} was not canonicalized properly {:?}",
|
||||||
|
target, canonical_path
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some((entry, canonical_path))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
|
||||||
|
Ok(self
|
||||||
|
.try_entry(target, true)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!(io::Error::new(
|
||||||
|
io::ErrorKind::NotFound,
|
||||||
|
format!("not found: {target:?}")
|
||||||
|
))
|
||||||
|
})?
|
||||||
|
.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
|
||||||
where
|
where
|
||||||
Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
|
Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
|
||||||
{
|
{
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
let filename = path.file_name().context("cannot overwrite the root")?;
|
let filename = path.file_name().context("cannot overwrite the root")?;
|
||||||
let parent_path = path.parent().unwrap();
|
let parent_path = path.parent().unwrap();
|
||||||
|
|
||||||
let parent = self.read_path(parent_path)?;
|
let parent = self.entry(parent_path)?;
|
||||||
let mut parent = parent.lock();
|
|
||||||
let new_entry = parent
|
let new_entry = parent
|
||||||
.dir_entries(parent_path)?
|
.dir_entries(parent_path)?
|
||||||
.entry(filename.to_str().unwrap().into());
|
.entry(filename.to_str().unwrap().into());
|
||||||
|
@ -1105,13 +1201,13 @@ impl FakeFs {
|
||||||
this: this.clone(),
|
this: this.clone(),
|
||||||
executor: executor.clone(),
|
executor: executor.clone(),
|
||||||
state: Arc::new(Mutex::new(FakeFsState {
|
state: Arc::new(Mutex::new(FakeFsState {
|
||||||
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
|
root: FakeFsEntry::Dir {
|
||||||
inode: 0,
|
inode: 0,
|
||||||
mtime: MTime(UNIX_EPOCH),
|
mtime: MTime(UNIX_EPOCH),
|
||||||
len: 0,
|
len: 0,
|
||||||
entries: Default::default(),
|
entries: Default::default(),
|
||||||
git_repo_state: None,
|
git_repo_state: None,
|
||||||
})),
|
},
|
||||||
git_event_tx: tx,
|
git_event_tx: tx,
|
||||||
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
|
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
|
||||||
next_inode: 1,
|
next_inode: 1,
|
||||||
|
@ -1161,15 +1257,15 @@ impl FakeFs {
|
||||||
.write_path(path, move |entry| {
|
.write_path(path, move |entry| {
|
||||||
match entry {
|
match entry {
|
||||||
btree_map::Entry::Vacant(e) => {
|
btree_map::Entry::Vacant(e) => {
|
||||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
e.insert(FakeFsEntry::File {
|
||||||
inode: new_inode,
|
inode: new_inode,
|
||||||
mtime: new_mtime,
|
mtime: new_mtime,
|
||||||
content: Vec::new(),
|
content: Vec::new(),
|
||||||
len: 0,
|
len: 0,
|
||||||
git_dir_path: None,
|
git_dir_path: None,
|
||||||
})));
|
});
|
||||||
}
|
}
|
||||||
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
|
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
|
||||||
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
|
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
|
||||||
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
|
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
|
||||||
FakeFsEntry::Symlink { .. } => {}
|
FakeFsEntry::Symlink { .. } => {}
|
||||||
|
@ -1188,7 +1284,7 @@ impl FakeFs {
|
||||||
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
|
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let path = path.as_ref();
|
let path = path.as_ref();
|
||||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
let file = FakeFsEntry::Symlink { target };
|
||||||
state
|
state
|
||||||
.write_path(path.as_ref(), move |e| match e {
|
.write_path(path.as_ref(), move |e| match e {
|
||||||
btree_map::Entry::Vacant(e) => {
|
btree_map::Entry::Vacant(e) => {
|
||||||
|
@ -1221,13 +1317,13 @@ impl FakeFs {
|
||||||
match entry {
|
match entry {
|
||||||
btree_map::Entry::Vacant(e) => {
|
btree_map::Entry::Vacant(e) => {
|
||||||
kind = Some(PathEventKind::Created);
|
kind = Some(PathEventKind::Created);
|
||||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
e.insert(FakeFsEntry::File {
|
||||||
inode: new_inode,
|
inode: new_inode,
|
||||||
mtime: new_mtime,
|
mtime: new_mtime,
|
||||||
len: new_len,
|
len: new_len,
|
||||||
content: new_content,
|
content: new_content,
|
||||||
git_dir_path: None,
|
git_dir_path: None,
|
||||||
})));
|
});
|
||||||
}
|
}
|
||||||
btree_map::Entry::Occupied(mut e) => {
|
btree_map::Entry::Occupied(mut e) => {
|
||||||
kind = Some(PathEventKind::Changed);
|
kind = Some(PathEventKind::Changed);
|
||||||
|
@ -1237,7 +1333,7 @@ impl FakeFs {
|
||||||
len,
|
len,
|
||||||
content,
|
content,
|
||||||
..
|
..
|
||||||
} = &mut *e.get_mut().lock()
|
} = e.get_mut()
|
||||||
{
|
{
|
||||||
*mtime = new_mtime;
|
*mtime = new_mtime;
|
||||||
*content = new_content;
|
*content = new_content;
|
||||||
|
@ -1259,9 +1355,8 @@ impl FakeFs {
|
||||||
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
|
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
|
||||||
let path = path.as_ref();
|
let path = path.as_ref();
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
let state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let entry = state.read_path(&path)?;
|
let entry = state.entry(&path)?;
|
||||||
let entry = entry.lock();
|
|
||||||
entry.file_content(&path).cloned()
|
entry.file_content(&path).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1269,9 +1364,8 @@ impl FakeFs {
|
||||||
let path = path.as_ref();
|
let path = path.as_ref();
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let entry = state.read_path(&path)?;
|
let entry = state.entry(&path)?;
|
||||||
let entry = entry.lock();
|
|
||||||
entry.file_content(&path).cloned()
|
entry.file_content(&path).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1292,6 +1386,25 @@ impl FakeFs {
|
||||||
self.state.lock().flush_events(count);
|
self.state.lock().flush_events(count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
|
||||||
|
self.state.lock().entry(target).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
|
||||||
|
let mut state = self.state.lock();
|
||||||
|
state.write_path(target, |entry| {
|
||||||
|
match entry {
|
||||||
|
btree_map::Entry::Vacant(vacant_entry) => {
|
||||||
|
vacant_entry.insert(new_entry);
|
||||||
|
}
|
||||||
|
btree_map::Entry::Occupied(mut occupied_entry) => {
|
||||||
|
occupied_entry.insert(new_entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn insert_tree<'a>(
|
pub fn insert_tree<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
|
@ -1361,20 +1474,19 @@ impl FakeFs {
|
||||||
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
|
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
|
||||||
{
|
{
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let entry = state.read_path(dot_git).context("open .git")?;
|
let git_event_tx = state.git_event_tx.clone();
|
||||||
let mut entry = entry.lock();
|
let entry = state.entry(dot_git).context("open .git")?;
|
||||||
|
|
||||||
if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
|
if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
|
||||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||||
log::debug!("insert git state for {dot_git:?}");
|
log::debug!("insert git state for {dot_git:?}");
|
||||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||||
state.git_event_tx.clone(),
|
|
||||||
)))
|
|
||||||
});
|
});
|
||||||
let mut repo_state = repo_state.lock();
|
let mut repo_state = repo_state.lock();
|
||||||
|
|
||||||
let result = f(&mut repo_state, dot_git, dot_git);
|
let result = f(&mut repo_state, dot_git, dot_git);
|
||||||
|
|
||||||
|
drop(repo_state);
|
||||||
if emit_git_event {
|
if emit_git_event {
|
||||||
state.emit_event([(dot_git, None)]);
|
state.emit_event([(dot_git, None)]);
|
||||||
}
|
}
|
||||||
|
@ -1398,21 +1510,20 @@ impl FakeFs {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.clone();
|
.clone();
|
||||||
drop(entry);
|
let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
|
||||||
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
|
|
||||||
anyhow::bail!("pointed-to git dir {path:?} not found")
|
anyhow::bail!("pointed-to git dir {path:?} not found")
|
||||||
};
|
};
|
||||||
let FakeFsEntry::Dir {
|
let FakeFsEntry::Dir {
|
||||||
git_repo_state,
|
git_repo_state,
|
||||||
entries,
|
entries,
|
||||||
..
|
..
|
||||||
} = &mut *git_dir_entry.lock()
|
} = git_dir_entry
|
||||||
else {
|
else {
|
||||||
anyhow::bail!("gitfile points to a non-directory")
|
anyhow::bail!("gitfile points to a non-directory")
|
||||||
};
|
};
|
||||||
let common_dir = if let Some(child) = entries.get("commondir") {
|
let common_dir = if let Some(child) = entries.get("commondir") {
|
||||||
Path::new(
|
Path::new(
|
||||||
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
|
std::str::from_utf8(child.file_content("commondir".as_ref())?)
|
||||||
.context("commondir content")?,
|
.context("commondir content")?,
|
||||||
)
|
)
|
||||||
.to_owned()
|
.to_owned()
|
||||||
|
@ -1420,15 +1531,14 @@ impl FakeFs {
|
||||||
canonical_path.clone()
|
canonical_path.clone()
|
||||||
};
|
};
|
||||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||||
state.git_event_tx.clone(),
|
|
||||||
)))
|
|
||||||
});
|
});
|
||||||
let mut repo_state = repo_state.lock();
|
let mut repo_state = repo_state.lock();
|
||||||
|
|
||||||
let result = f(&mut repo_state, &canonical_path, &common_dir);
|
let result = f(&mut repo_state, &canonical_path, &common_dir);
|
||||||
|
|
||||||
if emit_git_event {
|
if emit_git_event {
|
||||||
|
drop(repo_state);
|
||||||
state.emit_event([(canonical_path, None)]);
|
state.emit_event([(canonical_path, None)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1655,14 +1765,12 @@ impl FakeFs {
|
||||||
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
let mut queue = collections::VecDeque::new();
|
let mut queue = collections::VecDeque::new();
|
||||||
queue.push_back((
|
let state = &*self.state.lock();
|
||||||
PathBuf::from(util::path!("/")),
|
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||||
self.state.lock().root.clone(),
|
|
||||||
));
|
|
||||||
while let Some((path, entry)) = queue.pop_front() {
|
while let Some((path, entry)) = queue.pop_front() {
|
||||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||||
for (name, entry) in entries {
|
for (name, entry) in entries {
|
||||||
queue.push_back((path.join(name), entry.clone()));
|
queue.push_back((path.join(name), entry));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if include_dot_git
|
if include_dot_git
|
||||||
|
@ -1679,14 +1787,12 @@ impl FakeFs {
|
||||||
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
let mut queue = collections::VecDeque::new();
|
let mut queue = collections::VecDeque::new();
|
||||||
queue.push_back((
|
let state = &*self.state.lock();
|
||||||
PathBuf::from(util::path!("/")),
|
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||||
self.state.lock().root.clone(),
|
|
||||||
));
|
|
||||||
while let Some((path, entry)) = queue.pop_front() {
|
while let Some((path, entry)) = queue.pop_front() {
|
||||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||||
for (name, entry) in entries {
|
for (name, entry) in entries {
|
||||||
queue.push_back((path.join(name), entry.clone()));
|
queue.push_back((path.join(name), entry));
|
||||||
}
|
}
|
||||||
if include_dot_git
|
if include_dot_git
|
||||||
|| !path
|
|| !path
|
||||||
|
@ -1703,17 +1809,14 @@ impl FakeFs {
|
||||||
pub fn files(&self) -> Vec<PathBuf> {
|
pub fn files(&self) -> Vec<PathBuf> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
let mut queue = collections::VecDeque::new();
|
let mut queue = collections::VecDeque::new();
|
||||||
queue.push_back((
|
let state = &*self.state.lock();
|
||||||
PathBuf::from(util::path!("/")),
|
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||||
self.state.lock().root.clone(),
|
|
||||||
));
|
|
||||||
while let Some((path, entry)) = queue.pop_front() {
|
while let Some((path, entry)) = queue.pop_front() {
|
||||||
let e = entry.lock();
|
match entry {
|
||||||
match &*e {
|
|
||||||
FakeFsEntry::File { .. } => result.push(path),
|
FakeFsEntry::File { .. } => result.push(path),
|
||||||
FakeFsEntry::Dir { entries, .. } => {
|
FakeFsEntry::Dir { entries, .. } => {
|
||||||
for (name, entry) in entries {
|
for (name, entry) in entries {
|
||||||
queue.push_back((path.join(name), entry.clone()));
|
queue.push_back((path.join(name), entry));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
FakeFsEntry::Symlink { .. } => {}
|
FakeFsEntry::Symlink { .. } => {}
|
||||||
|
@ -1725,13 +1828,10 @@ impl FakeFs {
|
||||||
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
|
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
let mut queue = collections::VecDeque::new();
|
let mut queue = collections::VecDeque::new();
|
||||||
queue.push_back((
|
let state = &*self.state.lock();
|
||||||
PathBuf::from(util::path!("/")),
|
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||||
self.state.lock().root.clone(),
|
|
||||||
));
|
|
||||||
while let Some((path, entry)) = queue.pop_front() {
|
while let Some((path, entry)) = queue.pop_front() {
|
||||||
let e = entry.lock();
|
match entry {
|
||||||
match &*e {
|
|
||||||
FakeFsEntry::File { content, .. } => {
|
FakeFsEntry::File { content, .. } => {
|
||||||
if path.starts_with(prefix) {
|
if path.starts_with(prefix) {
|
||||||
result.push((path, content.clone()));
|
result.push((path, content.clone()));
|
||||||
|
@ -1739,7 +1839,7 @@ impl FakeFs {
|
||||||
}
|
}
|
||||||
FakeFsEntry::Dir { entries, .. } => {
|
FakeFsEntry::Dir { entries, .. } => {
|
||||||
for (name, entry) in entries {
|
for (name, entry) in entries {
|
||||||
queue.push_back((path.join(name), entry.clone()));
|
queue.push_back((path.join(name), entry));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
FakeFsEntry::Symlink { .. } => {}
|
FakeFsEntry::Symlink { .. } => {}
|
||||||
|
@ -1805,10 +1905,7 @@ impl FakeFsEntry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dir_entries(
|
fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
|
||||||
&mut self,
|
|
||||||
path: &Path,
|
|
||||||
) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
|
|
||||||
if let Self::Dir { entries, .. } = self {
|
if let Self::Dir { entries, .. } = self {
|
||||||
Ok(entries)
|
Ok(entries)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1855,12 +1952,12 @@ struct FakeHandle {
|
||||||
impl FileHandle for FakeHandle {
|
impl FileHandle for FakeHandle {
|
||||||
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
|
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
|
||||||
let fs = fs.as_fake();
|
let fs = fs.as_fake();
|
||||||
let state = fs.state.lock();
|
let mut state = fs.state.lock();
|
||||||
let Some(target) = state.moves.get(&self.inode) else {
|
let Some(target) = state.moves.get(&self.inode).cloned() else {
|
||||||
anyhow::bail!("fake fd not moved")
|
anyhow::bail!("fake fd not moved")
|
||||||
};
|
};
|
||||||
|
|
||||||
if state.try_read_path(&target, false).is_some() {
|
if state.try_entry(&target, false).is_some() {
|
||||||
return Ok(target.clone());
|
return Ok(target.clone());
|
||||||
}
|
}
|
||||||
anyhow::bail!("fake fd target not found")
|
anyhow::bail!("fake fd target not found")
|
||||||
|
@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
|
||||||
state.write_path(&cur_path, |entry| {
|
state.write_path(&cur_path, |entry| {
|
||||||
entry.or_insert_with(|| {
|
entry.or_insert_with(|| {
|
||||||
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
|
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
|
||||||
Arc::new(Mutex::new(FakeFsEntry::Dir {
|
FakeFsEntry::Dir {
|
||||||
inode,
|
inode,
|
||||||
mtime,
|
mtime,
|
||||||
len: 0,
|
len: 0,
|
||||||
entries: Default::default(),
|
entries: Default::default(),
|
||||||
git_repo_state: None,
|
git_repo_state: None,
|
||||||
}))
|
}
|
||||||
});
|
});
|
||||||
Ok(())
|
Ok(())
|
||||||
})?
|
})?
|
||||||
|
@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let inode = state.get_and_increment_inode();
|
let inode = state.get_and_increment_inode();
|
||||||
let mtime = state.get_and_increment_mtime();
|
let mtime = state.get_and_increment_mtime();
|
||||||
let file = Arc::new(Mutex::new(FakeFsEntry::File {
|
let file = FakeFsEntry::File {
|
||||||
inode,
|
inode,
|
||||||
mtime,
|
mtime,
|
||||||
len: 0,
|
len: 0,
|
||||||
content: Vec::new(),
|
content: Vec::new(),
|
||||||
git_dir_path: None,
|
git_dir_path: None,
|
||||||
}));
|
};
|
||||||
let mut kind = Some(PathEventKind::Created);
|
let mut kind = Some(PathEventKind::Created);
|
||||||
state.write_path(path, |entry| {
|
state.write_path(path, |entry| {
|
||||||
match entry {
|
match entry {
|
||||||
|
@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
|
||||||
|
|
||||||
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
|
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
let file = FakeFsEntry::Symlink { target };
|
||||||
state
|
state
|
||||||
.write_path(path.as_ref(), move |e| match e {
|
.write_path(path.as_ref(), move |e| match e {
|
||||||
btree_map::Entry::Vacant(e) => {
|
btree_map::Entry::Vacant(e) => {
|
||||||
|
@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
|
||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let inode = match *moved_entry.lock() {
|
let inode = match moved_entry {
|
||||||
FakeFsEntry::File { inode, .. } => inode,
|
FakeFsEntry::File { inode, .. } => inode,
|
||||||
FakeFsEntry::Dir { inode, .. } => inode,
|
FakeFsEntry::Dir { inode, .. } => inode,
|
||||||
_ => 0,
|
_ => 0,
|
||||||
|
@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let mtime = state.get_and_increment_mtime();
|
let mtime = state.get_and_increment_mtime();
|
||||||
let inode = state.get_and_increment_inode();
|
let inode = state.get_and_increment_inode();
|
||||||
let source_entry = state.read_path(&source)?;
|
let source_entry = state.entry(&source)?;
|
||||||
let content = source_entry.lock().file_content(&source)?.clone();
|
let content = source_entry.file_content(&source)?.clone();
|
||||||
let mut kind = Some(PathEventKind::Created);
|
let mut kind = Some(PathEventKind::Created);
|
||||||
state.write_path(&target, |e| match e {
|
state.write_path(&target, |e| match e {
|
||||||
btree_map::Entry::Occupied(e) => {
|
btree_map::Entry::Occupied(e) => {
|
||||||
|
@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
btree_map::Entry::Vacant(e) => Ok(Some(
|
btree_map::Entry::Vacant(e) => Ok(Some(
|
||||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
e.insert(FakeFsEntry::File {
|
||||||
inode,
|
inode,
|
||||||
mtime,
|
mtime,
|
||||||
len: content.len() as u64,
|
len: content.len() as u64,
|
||||||
content,
|
content,
|
||||||
git_dir_path: None,
|
git_dir_path: None,
|
||||||
})))
|
})
|
||||||
.clone(),
|
.clone(),
|
||||||
)),
|
)),
|
||||||
})?;
|
})?;
|
||||||
|
@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
|
||||||
let base_name = path.file_name().context("cannot remove the root")?;
|
let base_name = path.file_name().context("cannot remove the root")?;
|
||||||
|
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let parent_entry = state.read_path(parent_path)?;
|
let parent_entry = state.entry(parent_path)?;
|
||||||
let mut parent_entry = parent_entry.lock();
|
|
||||||
let entry = parent_entry
|
let entry = parent_entry
|
||||||
.dir_entries(parent_path)?
|
.dir_entries(parent_path)?
|
||||||
.entry(base_name.to_str().unwrap().into());
|
.entry(base_name.to_str().unwrap().into());
|
||||||
|
@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
|
||||||
anyhow::bail!("{path:?} does not exist");
|
anyhow::bail!("{path:?} does not exist");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
btree_map::Entry::Occupied(e) => {
|
btree_map::Entry::Occupied(mut entry) => {
|
||||||
{
|
{
|
||||||
let mut entry = e.get().lock();
|
let children = entry.get_mut().dir_entries(&path)?;
|
||||||
let children = entry.dir_entries(&path)?;
|
|
||||||
if !options.recursive && !children.is_empty() {
|
if !options.recursive && !children.is_empty() {
|
||||||
anyhow::bail!("{path:?} is not empty");
|
anyhow::bail!("{path:?} is not empty");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.remove();
|
entry.remove();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||||
|
@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
|
||||||
let parent_path = path.parent().context("cannot remove the root")?;
|
let parent_path = path.parent().context("cannot remove the root")?;
|
||||||
let base_name = path.file_name().unwrap();
|
let base_name = path.file_name().unwrap();
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let parent_entry = state.read_path(parent_path)?;
|
let parent_entry = state.entry(parent_path)?;
|
||||||
let mut parent_entry = parent_entry.lock();
|
|
||||||
let entry = parent_entry
|
let entry = parent_entry
|
||||||
.dir_entries(parent_path)?
|
.dir_entries(parent_path)?
|
||||||
.entry(base_name.to_str().unwrap().into());
|
.entry(base_name.to_str().unwrap().into());
|
||||||
|
@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
|
||||||
anyhow::bail!("{path:?} does not exist");
|
anyhow::bail!("{path:?} does not exist");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
btree_map::Entry::Occupied(e) => {
|
btree_map::Entry::Occupied(mut entry) => {
|
||||||
e.get().lock().file_content(&path)?;
|
entry.get_mut().file_content(&path)?;
|
||||||
e.remove();
|
entry.remove();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||||
|
@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
|
||||||
|
|
||||||
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
|
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let entry = state.read_path(&path)?;
|
let inode = match state.entry(&path)? {
|
||||||
let entry = entry.lock();
|
FakeFsEntry::File { inode, .. } => *inode,
|
||||||
let inode = match *entry {
|
FakeFsEntry::Dir { inode, .. } => *inode,
|
||||||
FakeFsEntry::File { inode, .. } => inode,
|
|
||||||
FakeFsEntry::Dir { inode, .. } => inode,
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
Ok(Arc::new(FakeHandle { inode }))
|
Ok(Arc::new(FakeHandle { inode }))
|
||||||
|
@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let state = self.state.lock();
|
let state = self.state.lock();
|
||||||
let (_, canonical_path) = state
|
let canonical_path = state
|
||||||
.try_read_path(&path, true)
|
.canonicalize(&path, true)
|
||||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||||
Ok(canonical_path)
|
Ok(canonical_path)
|
||||||
}
|
}
|
||||||
|
@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
|
||||||
async fn is_file(&self, path: &Path) -> bool {
|
async fn is_file(&self, path: &Path) -> bool {
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
if let Some((entry, _)) = state.try_read_path(&path, true) {
|
if let Some((entry, _)) = state.try_entry(&path, true) {
|
||||||
entry.lock().is_file()
|
entry.is_file()
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
state.metadata_call_count += 1;
|
state.metadata_call_count += 1;
|
||||||
if let Some((mut entry, _)) = state.try_read_path(&path, false) {
|
if let Some((mut entry, _)) = state.try_entry(&path, false) {
|
||||||
let is_symlink = entry.lock().is_symlink();
|
let is_symlink = entry.is_symlink();
|
||||||
if is_symlink {
|
if is_symlink {
|
||||||
if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
|
if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
|
||||||
entry = e;
|
entry = e;
|
||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let entry = entry.lock();
|
|
||||||
Ok(Some(match &*entry {
|
Ok(Some(match &*entry {
|
||||||
FakeFsEntry::File {
|
FakeFsEntry::File {
|
||||||
inode, mtime, len, ..
|
inode, mtime, len, ..
|
||||||
|
@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
|
||||||
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
|
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
|
||||||
self.simulate_random_delay().await;
|
self.simulate_random_delay().await;
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
let state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
let (entry, _) = state
|
let (entry, _) = state
|
||||||
.try_read_path(&path, false)
|
.try_entry(&path, false)
|
||||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||||
let entry = entry.lock();
|
if let FakeFsEntry::Symlink { target } = entry {
|
||||||
if let FakeFsEntry::Symlink { target } = &*entry {
|
|
||||||
Ok(target.clone())
|
Ok(target.clone())
|
||||||
} else {
|
} else {
|
||||||
anyhow::bail!("not a symlink: {path:?}")
|
anyhow::bail!("not a symlink: {path:?}")
|
||||||
|
@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
|
||||||
let path = normalize_path(path);
|
let path = normalize_path(path);
|
||||||
let mut state = self.state.lock();
|
let mut state = self.state.lock();
|
||||||
state.read_dir_call_count += 1;
|
state.read_dir_call_count += 1;
|
||||||
let entry = state.read_path(&path)?;
|
let entry = state.entry(&path)?;
|
||||||
let mut entry = entry.lock();
|
|
||||||
let children = entry.dir_entries(&path)?;
|
let children = entry.dir_entries(&path)?;
|
||||||
let paths = children
|
let paths = children
|
||||||
.keys()
|
.keys()
|
||||||
|
@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
|
||||||
dot_git_path: abs_dot_git.to_path_buf(),
|
dot_git_path: abs_dot_git.to_path_buf(),
|
||||||
repository_dir_path: repository_dir_path.to_owned(),
|
repository_dir_path: repository_dir_path.to_owned(),
|
||||||
common_dir_path: common_dir_path.to_owned(),
|
common_dir_path: common_dir_path.to_owned(),
|
||||||
|
checkpoints: Arc::default(),
|
||||||
}) as _
|
}) as _
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,7 +12,7 @@ workspace = true
|
||||||
path = "src/git.rs"
|
path = "src/git.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
test-support = []
|
test-support = ["rand"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
@ -26,6 +26,7 @@ http_client.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
|
rand = { workspace = true, optional = true }
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
|
||||||
unindent.workspace = true
|
unindent.workspace = true
|
||||||
gpui = { workspace = true, features = ["test-support"] }
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
tempfile.workspace = true
|
tempfile.workspace = true
|
||||||
|
rand.workspace = true
|
||||||
|
|
|
@ -119,6 +119,13 @@ impl Oid {
|
||||||
Ok(Self(oid))
|
Ok(Self(oid))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub fn random(rng: &mut impl rand::Rng) -> Self {
|
||||||
|
let mut bytes = [0; 20];
|
||||||
|
rng.fill(&mut bytes);
|
||||||
|
Self::from_bytes(&bytes).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn as_bytes(&self) -> &[u8] {
|
pub fn as_bytes(&self) -> &[u8] {
|
||||||
self.0.as_bytes()
|
self.0.as_bytes()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue