Compare commits
1 commit
main
...
acp-rewind
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c985789d84 |
13 changed files with 202 additions and 168 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -191,9 +191,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "agent-client-protocol"
|
name = "agent-client-protocol"
|
||||||
version = "0.0.31"
|
version = "0.0.32"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "289eb34ee17213dadcca47eedadd386a5e7678094095414e475965d1bcca2860"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-broadcast",
|
"async-broadcast",
|
||||||
|
|
|
@ -426,7 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||||
# External crates
|
# External crates
|
||||||
#
|
#
|
||||||
|
|
||||||
agent-client-protocol = "0.0.31"
|
agent-client-protocol = { path = "../agent-client-protocol"}
|
||||||
aho-corasick = "1.1"
|
aho-corasick = "1.1"
|
||||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||||
any_vec = "0.14"
|
any_vec = "0.14"
|
||||||
|
|
|
@ -32,10 +32,15 @@ use std::time::{Duration, Instant};
|
||||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||||
use ui::App;
|
use ui::App;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
pub fn new_prompt_id() -> acp::PromptId {
|
||||||
|
acp::PromptId(Uuid::new_v4().to_string().into())
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct UserMessage {
|
pub struct UserMessage {
|
||||||
pub id: Option<UserMessageId>,
|
pub prompt_id: Option<acp::PromptId>,
|
||||||
pub content: ContentBlock,
|
pub content: ContentBlock,
|
||||||
pub chunks: Vec<acp::ContentBlock>,
|
pub chunks: Vec<acp::ContentBlock>,
|
||||||
pub checkpoint: Option<Checkpoint>,
|
pub checkpoint: Option<Checkpoint>,
|
||||||
|
@ -962,7 +967,7 @@ impl AcpThread {
|
||||||
|
|
||||||
pub fn push_user_content_block(
|
pub fn push_user_content_block(
|
||||||
&mut self,
|
&mut self,
|
||||||
message_id: Option<UserMessageId>,
|
prompt_id: Option<acp::PromptId>,
|
||||||
chunk: acp::ContentBlock,
|
chunk: acp::ContentBlock,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) {
|
) {
|
||||||
|
@ -971,13 +976,13 @@ impl AcpThread {
|
||||||
|
|
||||||
if let Some(last_entry) = self.entries.last_mut()
|
if let Some(last_entry) = self.entries.last_mut()
|
||||||
&& let AgentThreadEntry::UserMessage(UserMessage {
|
&& let AgentThreadEntry::UserMessage(UserMessage {
|
||||||
id,
|
prompt_id: id,
|
||||||
content,
|
content,
|
||||||
chunks,
|
chunks,
|
||||||
..
|
..
|
||||||
}) = last_entry
|
}) = last_entry
|
||||||
{
|
{
|
||||||
*id = message_id.or(id.take());
|
*id = prompt_id.or(id.take());
|
||||||
content.append(chunk.clone(), &language_registry, cx);
|
content.append(chunk.clone(), &language_registry, cx);
|
||||||
chunks.push(chunk);
|
chunks.push(chunk);
|
||||||
let idx = entries_len - 1;
|
let idx = entries_len - 1;
|
||||||
|
@ -986,7 +991,7 @@ impl AcpThread {
|
||||||
let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
|
let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
|
||||||
self.push_entry(
|
self.push_entry(
|
||||||
AgentThreadEntry::UserMessage(UserMessage {
|
AgentThreadEntry::UserMessage(UserMessage {
|
||||||
id: message_id,
|
prompt_id,
|
||||||
content,
|
content,
|
||||||
chunks: vec![chunk],
|
chunks: vec![chunk],
|
||||||
checkpoint: None,
|
checkpoint: None,
|
||||||
|
@ -1336,6 +1341,7 @@ impl AcpThread {
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> BoxFuture<'static, Result<()>> {
|
) -> BoxFuture<'static, Result<()>> {
|
||||||
self.send(
|
self.send(
|
||||||
|
new_prompt_id(),
|
||||||
vec![acp::ContentBlock::Text(acp::TextContent {
|
vec![acp::ContentBlock::Text(acp::TextContent {
|
||||||
text: message.to_string(),
|
text: message.to_string(),
|
||||||
annotations: None,
|
annotations: None,
|
||||||
|
@ -1346,6 +1352,7 @@ impl AcpThread {
|
||||||
|
|
||||||
pub fn send(
|
pub fn send(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
prompt_id: acp::PromptId,
|
||||||
message: Vec<acp::ContentBlock>,
|
message: Vec<acp::ContentBlock>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> BoxFuture<'static, Result<()>> {
|
) -> BoxFuture<'static, Result<()>> {
|
||||||
|
@ -1355,22 +1362,17 @@ impl AcpThread {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
let request = acp::PromptRequest {
|
let request = acp::PromptRequest {
|
||||||
|
prompt_id: Some(prompt_id.clone()),
|
||||||
prompt: message.clone(),
|
prompt: message.clone(),
|
||||||
session_id: self.session_id.clone(),
|
session_id: self.session_id.clone(),
|
||||||
};
|
};
|
||||||
let git_store = self.project.read(cx).git_store().clone();
|
let git_store = self.project.read(cx).git_store().clone();
|
||||||
|
|
||||||
let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
|
|
||||||
Some(UserMessageId::new())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
self.run_turn(cx, async move |this, cx| {
|
self.run_turn(cx, async move |this, cx| {
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.push_entry(
|
this.push_entry(
|
||||||
AgentThreadEntry::UserMessage(UserMessage {
|
AgentThreadEntry::UserMessage(UserMessage {
|
||||||
id: message_id.clone(),
|
prompt_id: Some(prompt_id),
|
||||||
content: block,
|
content: block,
|
||||||
chunks: message,
|
chunks: message,
|
||||||
checkpoint: None,
|
checkpoint: None,
|
||||||
|
@ -1392,7 +1394,7 @@ impl AcpThread {
|
||||||
show: false,
|
show: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
this.connection.prompt(message_id, request, cx)
|
this.connection.prompt(request, cx)
|
||||||
})?
|
})?
|
||||||
.await
|
.await
|
||||||
})
|
})
|
||||||
|
@ -1509,8 +1511,8 @@ impl AcpThread {
|
||||||
|
|
||||||
/// Rewinds this thread to before the entry at `index`, removing it and all
|
/// Rewinds this thread to before the entry at `index`, removing it and all
|
||||||
/// subsequent entries while reverting any changes made from that point.
|
/// subsequent entries while reverting any changes made from that point.
|
||||||
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||||
let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
|
let Some(rewind) = self.connection.rewind(&self.session_id, cx) else {
|
||||||
return Task::ready(Err(anyhow!("not supported")));
|
return Task::ready(Err(anyhow!("not supported")));
|
||||||
};
|
};
|
||||||
let Some(message) = self.user_message(&id) else {
|
let Some(message) = self.user_message(&id) else {
|
||||||
|
@ -1530,7 +1532,7 @@ impl AcpThread {
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
|
cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?;
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
if let Some((ix, _)) = this.user_message_mut(&id) {
|
if let Some((ix, _)) = this.user_message_mut(&id) {
|
||||||
let range = ix..this.entries.len();
|
let range = ix..this.entries.len();
|
||||||
|
@ -1594,10 +1596,10 @@ impl AcpThread {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> {
|
||||||
self.entries.iter().find_map(|entry| {
|
self.entries.iter().find_map(|entry| {
|
||||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
if message.id.as_ref() == Some(id) {
|
if message.prompt_id.as_ref() == Some(id) {
|
||||||
Some(message)
|
Some(message)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -1608,10 +1610,10 @@ impl AcpThread {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
|
fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> {
|
||||||
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
|
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
|
||||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||||
if message.id.as_ref() == Some(id) {
|
if message.prompt_id.as_ref() == Some(id) {
|
||||||
Some((ix, message))
|
Some((ix, message))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -1905,7 +1907,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 1);
|
assert_eq!(thread.entries.len(), 1);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||||
assert_eq!(user_msg.id, None);
|
assert_eq!(user_msg.prompt_id, None);
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage");
|
panic!("Expected UserMessage");
|
||||||
|
@ -1913,7 +1915,7 @@ mod tests {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Test appending to existing user message
|
// Test appending to existing user message
|
||||||
let message_1_id = UserMessageId::new();
|
let message_1_id = new_prompt_id();
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(
|
thread.push_user_content_block(
|
||||||
Some(message_1_id.clone()),
|
Some(message_1_id.clone()),
|
||||||
|
@ -1928,7 +1930,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 1);
|
assert_eq!(thread.entries.len(), 1);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||||
assert_eq!(user_msg.id, Some(message_1_id));
|
assert_eq!(user_msg.prompt_id, Some(message_1_id));
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage");
|
panic!("Expected UserMessage");
|
||||||
|
@ -1947,7 +1949,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
let message_2_id = UserMessageId::new();
|
let message_2_id = new_prompt_id();
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(
|
thread.push_user_content_block(
|
||||||
Some(message_2_id.clone()),
|
Some(message_2_id.clone()),
|
||||||
|
@ -1962,7 +1964,7 @@ mod tests {
|
||||||
thread.update(cx, |thread, cx| {
|
thread.update(cx, |thread, cx| {
|
||||||
assert_eq!(thread.entries.len(), 3);
|
assert_eq!(thread.entries.len(), 3);
|
||||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
||||||
assert_eq!(user_msg.id, Some(message_2_id));
|
assert_eq!(user_msg.prompt_id, Some(message_2_id));
|
||||||
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected UserMessage at index 2");
|
panic!("Expected UserMessage at index 2");
|
||||||
|
@ -2259,9 +2261,13 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["Hi".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
||||||
}
|
}
|
||||||
|
@ -2320,9 +2326,13 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["Lorem".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
thread.read_with(cx, |thread, cx| {
|
thread.read_with(cx, |thread, cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(cx),
|
thread.to_markdown(cx),
|
||||||
|
@ -2340,9 +2350,13 @@ mod tests {
|
||||||
});
|
});
|
||||||
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
|
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
|
||||||
|
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["ipsum".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
thread.read_with(cx, |thread, cx| {
|
thread.read_with(cx, |thread, cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(cx),
|
thread.to_markdown(cx),
|
||||||
|
@ -2376,9 +2390,13 @@ mod tests {
|
||||||
|
|
||||||
// Checkpoint isn't stored when there are no changes.
|
// Checkpoint isn't stored when there are no changes.
|
||||||
simulate_changes.store(false, SeqCst);
|
simulate_changes.store(false, SeqCst);
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["dolor".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
thread.read_with(cx, |thread, cx| {
|
thread.read_with(cx, |thread, cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(cx),
|
thread.to_markdown(cx),
|
||||||
|
@ -2424,7 +2442,7 @@ mod tests {
|
||||||
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
|
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
|
||||||
panic!("unexpected entries {:?}", thread.entries)
|
panic!("unexpected entries {:?}", thread.entries)
|
||||||
};
|
};
|
||||||
thread.rewind(message.id.clone().unwrap(), cx)
|
thread.rewind(message.prompt_id.clone().unwrap(), cx)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -2490,9 +2508,13 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["hello".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
thread.read_with(cx, |thread, cx| {
|
thread.read_with(cx, |thread, cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(cx),
|
thread.to_markdown(cx),
|
||||||
|
@ -2512,9 +2534,13 @@ mod tests {
|
||||||
// Simulate refusing the second message, ensuring the conversation gets
|
// Simulate refusing the second message, ensuring the conversation gets
|
||||||
// truncated to before sending it.
|
// truncated to before sending it.
|
||||||
refuse_next.store(true, SeqCst);
|
refuse_next.store(true, SeqCst);
|
||||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
|
cx.update(|cx| {
|
||||||
.await
|
thread.update(cx, |thread, cx| {
|
||||||
.unwrap();
|
thread.send(new_prompt_id(), vec!["world".into()], cx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
thread.read_with(cx, |thread, cx| {
|
thread.read_with(cx, |thread, cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
thread.to_markdown(cx),
|
thread.to_markdown(cx),
|
||||||
|
@ -2653,7 +2679,6 @@ mod tests {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
_id: Option<UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
@ -2683,11 +2708,11 @@ mod tests {
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncate(
|
fn rewind(
|
||||||
&self,
|
&self,
|
||||||
session_id: &acp::SessionId,
|
session_id: &acp::SessionId,
|
||||||
_cx: &App,
|
_cx: &App,
|
||||||
) -> Option<Rc<dyn AgentSessionTruncate>> {
|
) -> Option<Rc<dyn AgentSessionRewind>> {
|
||||||
Some(Rc::new(FakeAgentSessionEditor {
|
Some(Rc::new(FakeAgentSessionEditor {
|
||||||
_session_id: session_id.clone(),
|
_session_id: session_id.clone(),
|
||||||
}))
|
}))
|
||||||
|
@ -2702,8 +2727,8 @@ mod tests {
|
||||||
_session_id: acp::SessionId,
|
_session_id: acp::SessionId,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentSessionTruncate for FakeAgentSessionEditor {
|
impl AgentSessionRewind for FakeAgentSessionEditor {
|
||||||
fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
|
fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,19 +5,8 @@ use collections::IndexMap;
|
||||||
use gpui::{Entity, SharedString, Task};
|
use gpui::{Entity, SharedString, Task};
|
||||||
use language_model::LanguageModelProviderId;
|
use language_model::LanguageModelProviderId;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use serde::{Deserialize, Serialize};
|
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc};
|
||||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
|
||||||
use ui::{App, IconName};
|
use ui::{App, IconName};
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
|
||||||
pub struct UserMessageId(Arc<str>);
|
|
||||||
|
|
||||||
impl UserMessageId {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self(Uuid::new_v4().to_string().into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait AgentConnection {
|
pub trait AgentConnection {
|
||||||
fn new_thread(
|
fn new_thread(
|
||||||
|
@ -31,12 +20,8 @@ pub trait AgentConnection {
|
||||||
|
|
||||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
||||||
&self,
|
-> Task<Result<acp::PromptResponse>>;
|
||||||
user_message_id: Option<UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
|
||||||
cx: &mut App,
|
|
||||||
) -> Task<Result<acp::PromptResponse>>;
|
|
||||||
|
|
||||||
fn resume(
|
fn resume(
|
||||||
&self,
|
&self,
|
||||||
|
@ -48,11 +33,11 @@ pub trait AgentConnection {
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||||
|
|
||||||
fn truncate(
|
fn rewind(
|
||||||
&self,
|
&self,
|
||||||
_session_id: &acp::SessionId,
|
_session_id: &acp::SessionId,
|
||||||
_cx: &App,
|
_cx: &App,
|
||||||
) -> Option<Rc<dyn AgentSessionTruncate>> {
|
) -> Option<Rc<dyn AgentSessionRewind>> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,8 +70,8 @@ impl dyn AgentConnection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait AgentSessionTruncate {
|
pub trait AgentSessionRewind {
|
||||||
fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait AgentSessionResume {
|
pub trait AgentSessionResume {
|
||||||
|
@ -362,7 +347,6 @@ mod test_support {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
_id: Option<UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
@ -432,11 +416,11 @@ mod test_support {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncate(
|
fn rewind(
|
||||||
&self,
|
&self,
|
||||||
_session_id: &agent_client_protocol::SessionId,
|
_session_id: &agent_client_protocol::SessionId,
|
||||||
_cx: &App,
|
_cx: &App,
|
||||||
) -> Option<Rc<dyn AgentSessionTruncate>> {
|
) -> Option<Rc<dyn AgentSessionRewind>> {
|
||||||
Some(Rc::new(StubAgentSessionEditor))
|
Some(Rc::new(StubAgentSessionEditor))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -447,8 +431,8 @@ mod test_support {
|
||||||
|
|
||||||
struct StubAgentSessionEditor;
|
struct StubAgentSessionEditor;
|
||||||
|
|
||||||
impl AgentSessionTruncate for StubAgentSessionEditor {
|
impl AgentSessionRewind for StubAgentSessionEditor {
|
||||||
fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
|
fn rewind(&self, _: acp::PromptId, _: &mut App) -> Task<Result<()>> {
|
||||||
Task::ready(Ok(()))
|
Task::ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -905,11 +905,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
id: Option<acp_thread::UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
let id = id.expect("UserMessageId is required");
|
let id = params.prompt_id.expect("UserMessageId is required");
|
||||||
let session_id = params.session_id.clone();
|
let session_id = params.session_id.clone();
|
||||||
log::info!("Received prompt request for session: {}", session_id);
|
log::info!("Received prompt request for session: {}", session_id);
|
||||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||||
|
@ -948,11 +947,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncate(
|
fn rewind(
|
||||||
&self,
|
&self,
|
||||||
session_id: &agent_client_protocol::SessionId,
|
session_id: &agent_client_protocol::SessionId,
|
||||||
cx: &App,
|
cx: &App,
|
||||||
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
|
) -> Option<Rc<dyn acp_thread::AgentSessionRewind>> {
|
||||||
self.0.read_with(cx, |agent, _cx| {
|
self.0.read_with(cx, |agent, _cx| {
|
||||||
agent.sessions.get(session_id).map(|session| {
|
agent.sessions.get(session_id).map(|session| {
|
||||||
Rc::new(NativeAgentSessionEditor {
|
Rc::new(NativeAgentSessionEditor {
|
||||||
|
@ -1009,10 +1008,10 @@ struct NativeAgentSessionEditor {
|
||||||
acp_thread: WeakEntity<AcpThread>,
|
acp_thread: WeakEntity<AcpThread>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor {
|
impl acp_thread::AgentSessionRewind for NativeAgentSessionEditor {
|
||||||
fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>> {
|
||||||
match self.thread.update(cx, |thread, cx| {
|
match self.thread.update(cx, |thread, cx| {
|
||||||
thread.truncate(message_id.clone(), cx)?;
|
thread.rewind(message_id.clone(), cx)?;
|
||||||
Ok(thread.latest_token_usage())
|
Ok(thread.latest_token_usage())
|
||||||
}) {
|
}) {
|
||||||
Ok(usage) => {
|
Ok(usage) => {
|
||||||
|
@ -1065,6 +1064,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use acp_thread::{
|
use acp_thread::{
|
||||||
AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
|
AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
|
||||||
|
new_prompt_id,
|
||||||
};
|
};
|
||||||
use fs::FakeFs;
|
use fs::FakeFs;
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
|
@ -1311,6 +1311,7 @@ mod tests {
|
||||||
|
|
||||||
let send = acp_thread.update(cx, |thread, cx| {
|
let send = acp_thread.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
|
new_prompt_id(),
|
||||||
vec![
|
vec![
|
||||||
"What does ".into(),
|
"What does ".into(),
|
||||||
acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
|
||||||
use acp_thread::UserMessageId;
|
use acp_thread::new_prompt_id;
|
||||||
use agent::{thread::DetailedSummaryState, thread_store};
|
use agent::{thread::DetailedSummaryState, thread_store};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use agent_settings::{AgentProfileId, CompletionMode};
|
use agent_settings::{AgentProfileId, CompletionMode};
|
||||||
|
@ -43,7 +43,7 @@ pub struct DbThread {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub cumulative_token_usage: language_model::TokenUsage,
|
pub cumulative_token_usage: language_model::TokenUsage,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>,
|
pub request_token_usage: HashMap<acp::PromptId, language_model::TokenUsage>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub model: Option<DbLanguageModel>,
|
pub model: Option<DbLanguageModel>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -97,7 +97,7 @@ impl DbThread {
|
||||||
content.push(UserMessageContent::Text(msg.context));
|
content.push(UserMessageContent::Text(msg.context));
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = UserMessageId::new();
|
let id = new_prompt_id();
|
||||||
last_user_message_id = Some(id.clone());
|
last_user_message_id = Some(id.clone());
|
||||||
|
|
||||||
crate::Message::User(UserMessage {
|
crate::Message::User(UserMessage {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
|
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, new_prompt_id};
|
||||||
use agent_client_protocol::{self as acp};
|
use agent_client_protocol::{self as acp};
|
||||||
use agent_settings::AgentProfileId;
|
use agent_settings::AgentProfileId;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
@ -48,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
thread.send(new_prompt_id(), ["Testing: Reply with 'Hello'"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -79,7 +79,7 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
[indoc! {"
|
[indoc! {"
|
||||||
Testing:
|
Testing:
|
||||||
|
|
||||||
|
@ -130,9 +130,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let mut pending_completions = fake_model.pending_completions();
|
let mut pending_completions = fake_model.pending_completions();
|
||||||
|
@ -168,7 +166,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
// Send initial user message and verify it's cached
|
// Send initial user message and verify it's cached
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
thread.send(new_prompt_id(), ["Message 1"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -191,7 +189,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
// Send another user message and verify only the latest is cached
|
// Send another user message and verify only the latest is cached
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
thread.send(new_prompt_id(), ["Message 2"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -227,7 +225,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
thread.send(new_prompt_id(), ["Use the echo tool"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -304,7 +302,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -320,7 +318,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||||
thread.remove_tool(&EchoTool::name());
|
thread.remove_tool(&EchoTool::name());
|
||||||
thread.add_tool(DelayTool);
|
thread.add_tool(DelayTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
[
|
[
|
||||||
"Now call the delay tool with 200ms.",
|
"Now call the delay tool with 200ms.",
|
||||||
"When the timer goes off, then you echo the output of the tool.",
|
"When the timer goes off, then you echo the output of the tool.",
|
||||||
|
@ -363,7 +361,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(WordListTool);
|
thread.add_tool(WordListTool);
|
||||||
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
thread.send(new_prompt_id(), ["Test the word_list tool."], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -414,7 +412,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(ToolRequiringPermission);
|
thread.add_tool(ToolRequiringPermission);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(new_prompt_id(), ["abc"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -544,9 +542,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||||
|
@ -575,7 +571,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(new_prompt_id(), ["abc"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -684,7 +680,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
thread.send(new_prompt_id(), ["abc"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -718,7 +714,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), vec!["ghi"], cx)
|
thread.send(new_prompt_id(), vec!["ghi"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -818,7 +814,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(DelayTool);
|
thread.add_tool(DelayTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
[
|
[
|
||||||
"Call the delay tool twice in the same message.",
|
"Call the delay tool twice in the same message.",
|
||||||
"Once with 100ms. Once with 300ms.",
|
"Once with 100ms. Once with 300ms.",
|
||||||
|
@ -898,7 +894,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-1".into()));
|
thread.set_profile(AgentProfileId("test-1".into()));
|
||||||
thread.send(UserMessageId::new(), ["test"], cx)
|
thread.send(new_prompt_id(), ["test"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -918,7 +914,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_profile(AgentProfileId("test-2".into()));
|
thread.set_profile(AgentProfileId("test-2".into()));
|
||||||
thread.send(UserMessageId::new(), ["test2"], cx)
|
thread.send(new_prompt_id(), ["test2"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -986,7 +982,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
|
|
||||||
let events = thread.update(cx, |thread, cx| {
|
let events = thread.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
|
thread.send(new_prompt_id(), ["Hey"], cx).unwrap()
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
@ -1028,7 +1024,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||||
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
// Send again after adding the echo tool, ensuring the name collision is resolved.
|
||||||
let events = thread.update(cx, |thread, cx| {
|
let events = thread.update(cx, |thread, cx| {
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
|
thread.send(new_prompt_id(), ["Go"], cx).unwrap()
|
||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let completion = fake_model.pending_completions().pop().unwrap();
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
@ -1235,9 +1231,7 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
|
||||||
);
|
);
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["Go"], cx))
|
||||||
thread.send(UserMessageId::new(), ["Go"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
let completion = fake_model.pending_completions().pop().unwrap();
|
let completion = fake_model.pending_completions().pop().unwrap();
|
||||||
|
@ -1271,7 +1265,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
thread.add_tool(InfiniteTool);
|
thread.add_tool(InfiniteTool);
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
["Call the echo tool, then call the infinite tool, then explain their output"],
|
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -1327,7 +1321,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
UserMessageId::new(),
|
new_prompt_id(),
|
||||||
["Testing: reply with 'Hello' then stop."],
|
["Testing: reply with 'Hello' then stop."],
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -1353,7 +1347,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events_1 = thread
|
let events_1 = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
thread.send(new_prompt_id(), ["Hello 1"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1362,7 +1356,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events_2 = thread
|
let events_2 = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
thread.send(new_prompt_id(), ["Hello 2"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1384,7 +1378,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events_1 = thread
|
let events_1 = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
thread.send(new_prompt_id(), ["Hello 1"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1396,7 +1390,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
let events_2 = thread
|
let events_2 = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
thread.send(new_prompt_id(), ["Hello 2"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1416,9 +1410,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let events = thread
|
let events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
|
||||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
|
@ -1464,7 +1456,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let message_id = UserMessageId::new();
|
let message_id = new_prompt_id();
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(message_id.clone(), ["Hello"], cx)
|
thread.send(message_id.clone(), ["Hello"], cx)
|
||||||
|
@ -1516,7 +1508,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| thread.truncate(message_id, cx))
|
.update(cx, |thread, cx| thread.rewind(message_id, cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
thread.read_with(cx, |thread, _| {
|
thread.read_with(cx, |thread, _| {
|
||||||
|
@ -1526,9 +1518,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
// Ensure we can still send a new message after truncation.
|
// Ensure we can still send a new message after truncation.
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hi"], cx))
|
||||||
thread.send(UserMessageId::new(), ["Hi"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
thread.update(cx, |thread, _cx| {
|
thread.update(cx, |thread, _cx| {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1582,7 +1572,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
thread.send(new_prompt_id(), ["Message 1"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1625,7 +1615,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||||
|
|
||||||
assert_first_message_state(cx);
|
assert_first_message_state(cx);
|
||||||
|
|
||||||
let second_message_id = UserMessageId::new();
|
let second_message_id = new_prompt_id();
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(second_message_id.clone(), ["Message 2"], cx)
|
thread.send(second_message_id.clone(), ["Message 2"], cx)
|
||||||
|
@ -1677,7 +1667,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| thread.truncate(second_message_id, cx))
|
.update(cx, |thread, cx| thread.rewind(second_message_id, cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
@ -1696,9 +1686,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
|
||||||
});
|
});
|
||||||
|
|
||||||
let send = thread
|
let send = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
|
||||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
@ -1719,7 +1707,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
|
||||||
// Send another message, ensuring no title is generated this time.
|
// Send another message, ensuring no title is generated this time.
|
||||||
let send = thread
|
let send = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(UserMessageId::new(), ["Hello again"], cx)
|
thread.send(new_prompt_id(), ["Hello again"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1740,7 +1728,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.add_tool(ToolRequiringPermission);
|
thread.add_tool(ToolRequiringPermission);
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["Hey!"], cx)
|
thread.send(new_prompt_id(), ["Hey!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1892,7 +1880,9 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
let model = model.as_fake();
|
let model = model.as_fake();
|
||||||
assert_eq!(model.id().0, "fake", "should return default model");
|
assert_eq!(model.id().0, "fake", "should return default model");
|
||||||
|
|
||||||
let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
|
let request = acp_thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(new_prompt_id(), vec!["abc".into()], cx)
|
||||||
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
model.send_last_completion_stream_text_chunk("def");
|
model.send_last_completion_stream_text_chunk("def");
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -1922,8 +1912,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||||
let result = cx
|
let result = cx
|
||||||
.update(|cx| {
|
.update(|cx| {
|
||||||
connection.prompt(
|
connection.prompt(
|
||||||
Some(acp_thread::UserMessageId::new()),
|
|
||||||
acp::PromptRequest {
|
acp::PromptRequest {
|
||||||
|
prompt_id: Some(new_prompt_id()),
|
||||||
session_id: session_id.clone(),
|
session_id: session_id.clone(),
|
||||||
prompt: vec!["ghi".into()],
|
prompt: vec!["ghi".into()],
|
||||||
},
|
},
|
||||||
|
@ -1946,9 +1936,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||||
let fake_model = model.as_fake();
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| thread.send(new_prompt_id(), ["Think"], cx))
|
||||||
thread.send(UserMessageId::new(), ["Think"], cx)
|
|
||||||
})
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
@ -2049,7 +2037,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(new_prompt_id(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -2093,7 +2081,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(new_prompt_id(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -2159,7 +2147,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.add_tool(EchoTool);
|
thread.add_tool(EchoTool);
|
||||||
thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
|
thread.send(new_prompt_id(), ["Call the echo tool!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -2234,7 +2222,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||||
let mut events = thread
|
let mut events = thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
|
||||||
thread.send(UserMessageId::new(), ["Hello!"], cx)
|
thread.send(new_prompt_id(), ["Hello!"], cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
||||||
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
|
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
|
||||||
Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
|
Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
|
||||||
};
|
};
|
||||||
use acp_thread::{MentionUri, UserMessageId};
|
use acp_thread::MentionUri;
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
|
@ -137,7 +137,7 @@ impl Message {
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct UserMessage {
|
pub struct UserMessage {
|
||||||
pub id: UserMessageId,
|
pub id: acp::PromptId,
|
||||||
pub content: Vec<UserMessageContent>,
|
pub content: Vec<UserMessageContent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -564,7 +564,7 @@ pub struct Thread {
|
||||||
pending_message: Option<AgentMessage>,
|
pending_message: Option<AgentMessage>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
|
request_token_usage: HashMap<acp::PromptId, language_model::TokenUsage>,
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
cumulative_token_usage: TokenUsage,
|
cumulative_token_usage: TokenUsage,
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
|
@ -1070,7 +1070,7 @@ impl Thread {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
|
pub fn rewind(&mut self, message_id: acp::PromptId, cx: &mut Context<Self>) -> Result<()> {
|
||||||
self.cancel(cx);
|
self.cancel(cx);
|
||||||
let Some(position) = self.messages.iter().position(
|
let Some(position) = self.messages.iter().position(
|
||||||
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
|
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
|
||||||
|
@ -1118,7 +1118,7 @@ impl Thread {
|
||||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||||
pub fn send<T>(
|
pub fn send<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
id: UserMessageId,
|
id: acp::PromptId,
|
||||||
content: impl IntoIterator<Item = T>,
|
content: impl IntoIterator<Item = T>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
|
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
|
||||||
|
|
|
@ -29,6 +29,7 @@ pub struct AcpConnection {
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
auth_methods: Vec<acp::AuthMethod>,
|
auth_methods: Vec<acp::AuthMethod>,
|
||||||
prompt_capabilities: acp::PromptCapabilities,
|
prompt_capabilities: acp::PromptCapabilities,
|
||||||
|
supports_rewind: bool,
|
||||||
_io_task: Task<Result<()>>,
|
_io_task: Task<Result<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,6 +148,7 @@ impl AcpConnection {
|
||||||
server_name,
|
server_name,
|
||||||
sessions,
|
sessions,
|
||||||
prompt_capabilities: response.agent_capabilities.prompt_capabilities,
|
prompt_capabilities: response.agent_capabilities.prompt_capabilities,
|
||||||
|
supports_rewind: response.agent_capabilities.rewind_session,
|
||||||
_io_task: io_task,
|
_io_task: io_task,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -225,9 +227,22 @@ impl AgentConnection for AcpConnection {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rewind(
|
||||||
|
&self,
|
||||||
|
session_id: &agent_client_protocol::SessionId,
|
||||||
|
_cx: &App,
|
||||||
|
) -> Option<Rc<dyn acp_thread::AgentSessionRewind>> {
|
||||||
|
if !self.supports_rewind {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(Rc::new(AcpRewinder {
|
||||||
|
connection: self.connection.clone(),
|
||||||
|
session_id: session_id.clone(),
|
||||||
|
}) as _)
|
||||||
|
}
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
_id: Option<acp_thread::UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
@ -302,6 +317,25 @@ impl AgentConnection for AcpConnection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AcpRewinder {
|
||||||
|
connection: Rc<acp::ClientSideConnection>,
|
||||||
|
session_id: acp::SessionId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl acp_thread::AgentSessionRewind for AcpRewinder {
|
||||||
|
fn rewind(&self, prompt_id: acp::PromptId, cx: &mut App) -> Task<Result<()>> {
|
||||||
|
let conn = self.connection.clone();
|
||||||
|
let params = acp::RewindRequest {
|
||||||
|
session_id: self.session_id.clone(),
|
||||||
|
prompt_id,
|
||||||
|
};
|
||||||
|
cx.foreground_executor().spawn(async move {
|
||||||
|
conn.rewind(params).await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ClientDelegate {
|
struct ClientDelegate {
|
||||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||||
cx: AsyncApp,
|
cx: AsyncApp,
|
||||||
|
|
|
@ -294,7 +294,6 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
_id: Option<acp_thread::UserMessageId>,
|
|
||||||
params: acp::PromptRequest,
|
params: acp::PromptRequest,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> Task<Result<acp::PromptResponse>> {
|
) -> Task<Result<acp::PromptResponse>> {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::AgentServer;
|
use crate::AgentServer;
|
||||||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus, new_prompt_id};
|
||||||
use agent_client_protocol as acp;
|
use agent_client_protocol as acp;
|
||||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||||
use gpui::{AppContext, Entity, TestAppContext};
|
use gpui::{AppContext, Entity, TestAppContext};
|
||||||
|
@ -77,6 +77,7 @@ where
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.send(
|
thread.send(
|
||||||
|
new_prompt_id(),
|
||||||
vec![
|
vec![
|
||||||
acp::ContentBlock::Text(acp::TextContent {
|
acp::ContentBlock::Text(acp::TextContent {
|
||||||
text: "Read the file ".into(),
|
text: "Read the file ".into(),
|
||||||
|
|
|
@ -67,7 +67,7 @@ impl EntryViewState {
|
||||||
|
|
||||||
match thread_entry {
|
match thread_entry {
|
||||||
AgentThreadEntry::UserMessage(message) => {
|
AgentThreadEntry::UserMessage(message) => {
|
||||||
let has_id = message.id.is_some();
|
let has_id = message.prompt_id.is_some();
|
||||||
let chunks = message.chunks.clone();
|
let chunks = message.chunks.clone();
|
||||||
if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
|
if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
|
||||||
if !editor.focus_handle(cx).is_focused(window) {
|
if !editor.focus_handle(cx).is_focused(window) {
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use acp_thread::{
|
use acp_thread::{
|
||||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||||
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
|
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
|
||||||
ToolCallStatus, UserMessageId,
|
ToolCallStatus, new_prompt_id,
|
||||||
};
|
};
|
||||||
use acp_thread::{AgentConnection, Plan};
|
use acp_thread::{AgentConnection, Plan};
|
||||||
use action_log::ActionLog;
|
use action_log::ActionLog;
|
||||||
|
@ -791,7 +791,7 @@ impl AcpThreadView {
|
||||||
if let Some(thread) = self.thread()
|
if let Some(thread) = self.thread()
|
||||||
&& let Some(AgentThreadEntry::UserMessage(user_message)) =
|
&& let Some(AgentThreadEntry::UserMessage(user_message)) =
|
||||||
thread.read(cx).entries().get(event.entry_index)
|
thread.read(cx).entries().get(event.entry_index)
|
||||||
&& user_message.id.is_some()
|
&& user_message.prompt_id.is_some()
|
||||||
{
|
{
|
||||||
self.editing_message = Some(event.entry_index);
|
self.editing_message = Some(event.entry_index);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -801,7 +801,7 @@ impl AcpThreadView {
|
||||||
if let Some(thread) = self.thread()
|
if let Some(thread) = self.thread()
|
||||||
&& let Some(AgentThreadEntry::UserMessage(user_message)) =
|
&& let Some(AgentThreadEntry::UserMessage(user_message)) =
|
||||||
thread.read(cx).entries().get(event.entry_index)
|
thread.read(cx).entries().get(event.entry_index)
|
||||||
&& user_message.id.is_some()
|
&& user_message.prompt_id.is_some()
|
||||||
{
|
{
|
||||||
if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) {
|
if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) {
|
||||||
self.editing_message = None;
|
self.editing_message = None;
|
||||||
|
@ -942,7 +942,7 @@ impl AcpThreadView {
|
||||||
|
|
||||||
telemetry::event!("Agent Message Sent", agent = agent_telemetry_id);
|
telemetry::event!("Agent Message Sent", agent = agent_telemetry_id);
|
||||||
|
|
||||||
thread.send(contents, cx)
|
thread.send(new_prompt_id(), contents, cx)
|
||||||
})?;
|
})?;
|
||||||
send.await
|
send.await
|
||||||
});
|
});
|
||||||
|
@ -1011,7 +1011,12 @@ impl AcpThreadView {
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(user_message_id) = thread.update(cx, |thread, _| {
|
let Some(user_message_id) = thread.update(cx, |thread, _| {
|
||||||
thread.entries().get(entry_ix)?.user_message()?.id.clone()
|
thread
|
||||||
|
.entries()
|
||||||
|
.get(entry_ix)?
|
||||||
|
.user_message()?
|
||||||
|
.prompt_id
|
||||||
|
.clone()
|
||||||
}) else {
|
}) else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
@ -1316,7 +1321,7 @@ impl AcpThreadView {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) {
|
fn rewind(&mut self, message_id: &acp::PromptId, cx: &mut Context<Self>) {
|
||||||
let Some(thread) = self.thread() else {
|
let Some(thread) = self.thread() else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
@ -1379,7 +1384,7 @@ impl AcpThreadView {
|
||||||
.gap_1p5()
|
.gap_1p5()
|
||||||
.w_full()
|
.w_full()
|
||||||
.children(rules_item)
|
.children(rules_item)
|
||||||
.children(message.id.clone().and_then(|message_id| {
|
.children(message.prompt_id.clone().and_then(|message_id| {
|
||||||
message.checkpoint.as_ref()?.show.then(|| {
|
message.checkpoint.as_ref()?.show.then(|| {
|
||||||
h_flex()
|
h_flex()
|
||||||
.px_3()
|
.px_3()
|
||||||
|
@ -1416,7 +1421,7 @@ impl AcpThreadView {
|
||||||
.map(|this|{
|
.map(|this|{
|
||||||
if editing && editor_focus {
|
if editing && editor_focus {
|
||||||
this.border_color(focus_border)
|
this.border_color(focus_border)
|
||||||
} else if message.id.is_some() {
|
} else if message.prompt_id.is_some() {
|
||||||
this.hover(|s| s.border_color(focus_border.opacity(0.8)))
|
this.hover(|s| s.border_color(focus_border.opacity(0.8)))
|
||||||
} else {
|
} else {
|
||||||
this
|
this
|
||||||
|
@ -1437,7 +1442,7 @@ impl AcpThreadView {
|
||||||
.bg(cx.theme().colors().editor_background)
|
.bg(cx.theme().colors().editor_background)
|
||||||
.overflow_hidden();
|
.overflow_hidden();
|
||||||
|
|
||||||
if message.id.is_some() {
|
if message.prompt_id.is_some() {
|
||||||
this.child(
|
this.child(
|
||||||
base_container
|
base_container
|
||||||
.child(
|
.child(
|
||||||
|
@ -5398,7 +5403,6 @@ pub(crate) mod tests {
|
||||||
|
|
||||||
fn prompt(
|
fn prompt(
|
||||||
&self,
|
&self,
|
||||||
_id: Option<acp_thread::UserMessageId>,
|
|
||||||
_params: acp::PromptRequest,
|
_params: acp::PromptRequest,
|
||||||
_cx: &mut App,
|
_cx: &mut App,
|
||||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||||
|
@ -5544,7 +5548,7 @@ pub(crate) mod tests {
|
||||||
let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
|
let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
|
||||||
panic!();
|
panic!();
|
||||||
};
|
};
|
||||||
user_message.id.clone().unwrap()
|
user_message.prompt_id.clone().unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
thread_view.read_with(cx, |view, cx| {
|
thread_view.read_with(cx, |view, cx| {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue