agent2: Initial infra for checkpoints and message editing (#36120)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
f4b0332f78
commit
23cd5b59b2
17 changed files with 1374 additions and 582 deletions
|
@ -36,6 +36,7 @@ terminal.workspace = true
|
|||
ui.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
|
|
|
@ -9,18 +9,19 @@ pub use mention::*;
|
|||
pub use terminal::*;
|
||||
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::{Context as _, Result};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use editor::Bias;
|
||||
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
||||
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use itertools::Itertools;
|
||||
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
|
||||
use markdown::Markdown;
|
||||
use project::{AgentLocation, Project};
|
||||
use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt::Formatter;
|
||||
use std::fmt::{Formatter, Write};
|
||||
use std::ops::Range;
|
||||
use std::process::ExitStatus;
|
||||
use std::rc::Rc;
|
||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||
|
@ -29,24 +30,23 @@ use util::ResultExt;
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct UserMessage {
|
||||
pub id: Option<UserMessageId>,
|
||||
pub content: ContentBlock,
|
||||
pub checkpoint: Option<GitStoreCheckpoint>,
|
||||
}
|
||||
|
||||
impl UserMessage {
|
||||
pub fn from_acp(
|
||||
message: impl IntoIterator<Item = acp::ContentBlock>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let mut content = ContentBlock::Empty;
|
||||
for chunk in message {
|
||||
content.append(chunk, &language_registry, cx)
|
||||
}
|
||||
Self { content: content }
|
||||
}
|
||||
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
|
||||
let mut markdown = String::new();
|
||||
if let Some(_) = self.checkpoint {
|
||||
writeln!(markdown, "## User (checkpoint)").unwrap();
|
||||
} else {
|
||||
writeln!(markdown, "## User").unwrap();
|
||||
}
|
||||
writeln!(markdown).unwrap();
|
||||
writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
|
||||
writeln!(markdown).unwrap();
|
||||
markdown
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -633,6 +633,7 @@ pub struct AcpThread {
|
|||
pub enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
EntryUpdated(usize),
|
||||
EntriesRemoved(Range<usize>),
|
||||
ToolAuthorizationRequired,
|
||||
Stopped,
|
||||
Error,
|
||||
|
@ -772,7 +773,7 @@ impl AcpThread {
|
|||
) -> Result<()> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(content, cx);
|
||||
self.push_user_content_block(None, content, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||
self.push_assistant_content_block(content, false, cx);
|
||||
|
@ -793,18 +794,32 @@ impl AcpThread {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
|
||||
pub fn push_user_content_block(
|
||||
&mut self,
|
||||
message_id: Option<UserMessageId>,
|
||||
chunk: acp::ContentBlock,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
let entries_len = self.entries.len();
|
||||
|
||||
if let Some(last_entry) = self.entries.last_mut()
|
||||
&& let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
|
||||
&& let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
|
||||
{
|
||||
*id = message_id.or(id.take());
|
||||
content.append(chunk, &language_registry, cx);
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
||||
let idx = entries_len - 1;
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(idx));
|
||||
} else {
|
||||
let content = ContentBlock::new(chunk, &language_registry, cx);
|
||||
self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
|
||||
self.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id,
|
||||
content,
|
||||
checkpoint: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -819,7 +834,8 @@ impl AcpThread {
|
|||
if let Some(last_entry) = self.entries.last_mut()
|
||||
&& let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
|
||||
{
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
|
||||
let idx = entries_len - 1;
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(idx));
|
||||
match (chunks.last_mut(), is_thought) {
|
||||
(Some(AssistantMessageChunk::Message { block }), false)
|
||||
| (Some(AssistantMessageChunk::Thought { block }), true) => {
|
||||
|
@ -1118,69 +1134,113 @@ impl AcpThread {
|
|||
self.project.read(cx).languages().clone(),
|
||||
cx,
|
||||
);
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||
let message_id = if self
|
||||
.connection
|
||||
.session_editor(&self.session_id, cx)
|
||||
.is_some()
|
||||
{
|
||||
Some(UserMessageId::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage { content: block }),
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id.clone(),
|
||||
content: block,
|
||||
checkpoint: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
self.clear_completed_plan_entries(cx);
|
||||
|
||||
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let cancel_task = self.cancel(cx);
|
||||
let request = acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
|
||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||
async {
|
||||
self.send_task = Some(cx.spawn({
|
||||
let message_id = message_id.clone();
|
||||
async move |this, cx| {
|
||||
cancel_task.await;
|
||||
|
||||
let result = this
|
||||
.update(cx, |this, cx| {
|
||||
this.connection.prompt(
|
||||
acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: this.session_id.clone(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
tx.send(result).log_err();
|
||||
|
||||
anyhow::Ok(())
|
||||
old_checkpoint_tx.send(old_checkpoint.await).ok();
|
||||
if let Ok(result) = this.update(cx, |this, cx| {
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
}) {
|
||||
tx.send(result.await).log_err();
|
||||
}
|
||||
}
|
||||
.await
|
||||
.log_err();
|
||||
}));
|
||||
|
||||
cx.spawn(async move |this, cx| match rx.await {
|
||||
Ok(Err(e)) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.send_task.take();
|
||||
cx.emit(AcpThreadEvent::Error)
|
||||
})
|
||||
cx.spawn(async move |this, cx| {
|
||||
let old_checkpoint = old_checkpoint_rx
|
||||
.await
|
||||
.map_err(|_| anyhow!("send canceled"))
|
||||
.flatten()
|
||||
.context("failed to get old checkpoint")
|
||||
.log_err();
|
||||
Err(e)?
|
||||
}
|
||||
result => {
|
||||
let cancelled = matches!(
|
||||
result,
|
||||
Ok(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::Cancelled
|
||||
}))
|
||||
);
|
||||
|
||||
// We only take the task if the current prompt wasn't cancelled.
|
||||
//
|
||||
// This prompt may have been cancelled because another one was sent
|
||||
// while it was still generating. In these cases, dropping `send_task`
|
||||
// would cause the next generation to be cancelled.
|
||||
if !cancelled {
|
||||
this.update(cx, |this, _cx| this.send_task.take()).ok();
|
||||
}
|
||||
let response = rx.await;
|
||||
|
||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
|
||||
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
|
||||
let new_checkpoint = git_store
|
||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||
.await
|
||||
.context("failed to get new checkpoint")
|
||||
.log_err();
|
||||
Ok(())
|
||||
if let Some(new_checkpoint) = new_checkpoint {
|
||||
let equal = git_store
|
||||
.update(cx, |git, cx| {
|
||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||
})?
|
||||
.await
|
||||
.unwrap_or(true);
|
||||
if !equal {
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((ix, message)) = this.user_message_mut(&message_id) {
|
||||
message.checkpoint = Some(old_checkpoint);
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
match response {
|
||||
Ok(Err(e)) => {
|
||||
this.send_task.take();
|
||||
cx.emit(AcpThreadEvent::Error);
|
||||
Err(e)
|
||||
}
|
||||
result => {
|
||||
let cancelled = matches!(
|
||||
result,
|
||||
Ok(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::Cancelled
|
||||
}))
|
||||
);
|
||||
|
||||
// We only take the task if the current prompt wasn't cancelled.
|
||||
//
|
||||
// This prompt may have been cancelled because another one was sent
|
||||
// while it was still generating. In these cases, dropping `send_task`
|
||||
// would cause the next generation to be cancelled.
|
||||
if !cancelled {
|
||||
this.send_task.take();
|
||||
}
|
||||
|
||||
cx.emit(AcpThreadEvent::Stopped);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
})?
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
@ -1212,6 +1272,66 @@ impl AcpThread {
|
|||
cx.foreground_executor().spawn(send_task)
|
||||
}
|
||||
|
||||
/// Rewinds this thread to before the entry at `index`, removing it and all
|
||||
/// subsequent entries while reverting any changes made from that point.
|
||||
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
|
||||
return Task::ready(Err(anyhow!("not supported")));
|
||||
};
|
||||
let Some(message) = self.user_message(&id) else {
|
||||
return Task::ready(Err(anyhow!("message not found")));
|
||||
};
|
||||
|
||||
let checkpoint = message.checkpoint.clone();
|
||||
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(checkpoint) = checkpoint {
|
||||
git_store
|
||||
.update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
|
||||
.await?;
|
||||
}
|
||||
|
||||
cx.update(|cx| session_editor.truncate(id.clone(), cx))?
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((ix, _)) = this.user_message_mut(&id) {
|
||||
let range = ix..this.entries.len();
|
||||
this.entries.truncate(ix);
|
||||
cx.emit(AcpThreadEvent::EntriesRemoved(range));
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||
self.entries.iter().find_map(|entry| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
if message.id.as_ref() == Some(&id) {
|
||||
Some(message)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
|
||||
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
if message.id.as_ref() == Some(&id) {
|
||||
Some((ix, message))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn read_text_file(
|
||||
&self,
|
||||
path: PathBuf,
|
||||
|
@ -1414,13 +1534,18 @@ mod tests {
|
|||
use futures::{channel::mpsc, future::LocalBoxFuture, select};
|
||||
use gpui::{AsyncApp, TestAppContext, WeakEntity};
|
||||
use indoc::indoc;
|
||||
use project::FakeFs;
|
||||
use project::{FakeFs, Fs};
|
||||
use rand::Rng as _;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
|
||||
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
|
||||
time::Duration,
|
||||
};
|
||||
use util::path;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
@ -1452,6 +1577,7 @@ mod tests {
|
|||
// Test creating a new user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(
|
||||
None,
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
annotations: None,
|
||||
text: "Hello, ".to_string(),
|
||||
|
@ -1463,6 +1589,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 1);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||
assert_eq!(user_msg.id, None);
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
||||
} else {
|
||||
panic!("Expected UserMessage");
|
||||
|
@ -1470,8 +1597,10 @@ mod tests {
|
|||
});
|
||||
|
||||
// Test appending to existing user message
|
||||
let message_1_id = UserMessageId::new();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(
|
||||
Some(message_1_id.clone()),
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
annotations: None,
|
||||
text: "world!".to_string(),
|
||||
|
@ -1483,6 +1612,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 1);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||
assert_eq!(user_msg.id, Some(message_1_id));
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
||||
} else {
|
||||
panic!("Expected UserMessage");
|
||||
|
@ -1501,8 +1631,10 @@ mod tests {
|
|||
);
|
||||
});
|
||||
|
||||
let message_2_id = UserMessageId::new();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(
|
||||
Some(message_2_id.clone()),
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
annotations: None,
|
||||
text: "New user message".to_string(),
|
||||
|
@ -1514,6 +1646,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 3);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
||||
assert_eq!(user_msg.id, Some(message_2_id));
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
||||
} else {
|
||||
panic!("Expected UserMessage at index 2");
|
||||
|
@ -1830,6 +1963,180 @@ mod tests {
|
|||
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_checkpoints(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.insert_tree(
|
||||
path!("/test"),
|
||||
json!({
|
||||
".git": {}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
|
||||
let simulate_changes = Arc::new(AtomicBool::new(true));
|
||||
let next_filename = Arc::new(AtomicUsize::new(0));
|
||||
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
|
||||
let simulate_changes = simulate_changes.clone();
|
||||
let next_filename = next_filename.clone();
|
||||
let fs = fs.clone();
|
||||
move |request, thread, mut cx| {
|
||||
let fs = fs.clone();
|
||||
let simulate_changes = simulate_changes.clone();
|
||||
let next_filename = next_filename.clone();
|
||||
async move {
|
||||
if simulate_changes.load(SeqCst) {
|
||||
let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
|
||||
fs.write(Path::new(&filename), b"").await?;
|
||||
}
|
||||
|
||||
let acp::ContentBlock::Text(content) = &request.prompt[0] else {
|
||||
panic!("expected text content block");
|
||||
};
|
||||
thread.update(&mut cx, |thread, cx| {
|
||||
thread
|
||||
.handle_session_update(
|
||||
acp::SessionUpdate::AgentMessageChunk {
|
||||
content: content.text.to_uppercase().into(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
})?;
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
}));
|
||||
let thread = connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User (checkpoint)
|
||||
|
||||
Lorem
|
||||
|
||||
## Assistant
|
||||
|
||||
LOREM
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User (checkpoint)
|
||||
|
||||
Lorem
|
||||
|
||||
## Assistant
|
||||
|
||||
LOREM
|
||||
|
||||
## User (checkpoint)
|
||||
|
||||
ipsum
|
||||
|
||||
## Assistant
|
||||
|
||||
IPSUM
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
assert_eq!(
|
||||
fs.files(),
|
||||
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
|
||||
);
|
||||
|
||||
// Checkpoint isn't stored when there are no changes.
|
||||
simulate_changes.store(false, SeqCst);
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User (checkpoint)
|
||||
|
||||
Lorem
|
||||
|
||||
## Assistant
|
||||
|
||||
LOREM
|
||||
|
||||
## User (checkpoint)
|
||||
|
||||
ipsum
|
||||
|
||||
## Assistant
|
||||
|
||||
IPSUM
|
||||
|
||||
## User
|
||||
|
||||
dolor
|
||||
|
||||
## Assistant
|
||||
|
||||
DOLOR
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
assert_eq!(
|
||||
fs.files(),
|
||||
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
|
||||
);
|
||||
|
||||
// Rewinding the conversation truncates the history and restores the checkpoint.
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
|
||||
panic!("unexpected entries {:?}", thread.entries)
|
||||
};
|
||||
thread.rewind(message.id.clone().unwrap(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User (checkpoint)
|
||||
|
||||
Lorem
|
||||
|
||||
## Assistant
|
||||
|
||||
LOREM
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
|
||||
}
|
||||
|
||||
async fn run_until_first_tool_call(
|
||||
thread: &Entity<AcpThread>,
|
||||
cx: &mut TestAppContext,
|
||||
|
@ -1938,6 +2245,7 @@ mod tests {
|
|||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
|
@ -1966,5 +2274,25 @@ mod tests {
|
|||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
Some(Rc::new(FakeAgentSessionEditor {
|
||||
_session_id: session_id.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeAgentSessionEditor {
|
||||
_session_id: acp::SessionId,
|
||||
}
|
||||
|
||||
impl AgentSessionEditor for FakeAgentSessionEditor {
|
||||
fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
||||
|
||||
use crate::AcpThread;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AcpThread;
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct UserMessageId(Arc<str>);
|
||||
|
||||
impl UserMessageId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
|
@ -21,11 +29,23 @@ pub trait AgentConnection {
|
|||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
||||
-> Task<Result<acp::PromptResponse>>;
|
||||
fn prompt(
|
||||
&self,
|
||||
user_message_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
|
@ -35,6 +55,10 @@ pub trait AgentConnection {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait AgentSessionEditor {
|
||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue