Lay the groundwork to support history in agent2 (#36483)

This pull request introduces title generation and history replaying. We
still need to wire up the rest of the history but this gets us very
close. I extracted a lot of this code from `agent2-history` because that
branch was starting to get long-lived and there were lots of changes
since we started.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-08-19 16:24:23 +02:00 committed by GitHub
parent c4083b9b63
commit 6c255c1973
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 929 additions and 328 deletions

3
Cargo.lock generated
View file

@ -191,6 +191,7 @@ version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
"agent",
"agent-client-protocol",
"agent_servers",
"agent_settings",
@ -208,6 +209,7 @@ dependencies = [
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"git",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
@ -256,6 +258,7 @@ name = "agent_servers"
version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
"agent-client-protocol",
"agent_settings",
"agentic-coding-protocol",

View file

@ -537,9 +537,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
}
acp::ToolCallContent::Diff { diff } => {
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
}
acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
Diff::finalized(
diff.path,
diff.old_text,
diff.new_text,
language_registry,
cx,
)
})),
}
}
@ -682,6 +688,7 @@ pub struct AcpThread {
#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
TitleUpdated,
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
@ -728,11 +735,9 @@ impl AcpThread {
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId,
cx: &mut Context<Self>,
) -> Self {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
action_log,
shared_buffers: Default::default(),
@ -926,6 +931,12 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
self.title = title;
cx.emit(AcpThreadEvent::TitleUpdated);
Ok(())
}
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
@ -1657,7 +1668,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::{FakeFs, Fs};
use rand::Rng as _;
@ -2327,7 +2338,7 @@ mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::App,
cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(
rand::thread_rng()
@ -2337,8 +2348,16 @@ mod tests {
.collect::<String>()
.into(),
);
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}

View file

@ -5,11 +5,12 @@ use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project;
use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
@ -208,6 +209,7 @@ impl AgentModelList {
mod test_support {
use std::sync::Arc;
use action_log::ActionLog;
use collections::HashMap;
use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity};
@ -295,8 +297,16 @@ mod test_support {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(
session_id,
Session {

View file

@ -1,4 +1,3 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
@ -21,17 +20,13 @@ pub enum Diff {
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
pub fn finalized(
path: PathBuf,
old_text: Option<String>,
new_text: String,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

View file

@ -2,6 +2,7 @@ use agent::ThreadId;
use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId};
use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::Range,
@ -11,7 +12,7 @@ use std::{
use ui::{App, IconName, SharedString};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MentionUri {
File {
abs_path: PathBuf,

View file

@ -14,6 +14,7 @@ workspace = true
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
@ -26,6 +27,7 @@ collections.workspace = true
context_server.workspace = true
fs.workspace = true
futures.workspace = true
git.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
@ -59,6 +61,7 @@ which.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }

View file

@ -1,10 +1,11 @@
use crate::{
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
@ -427,18 +428,19 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
let default_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|m| m.model.clone());
let registry = LanguageModelRegistry::read_global(cx);
let default_model = registry.default_model().map(|m| m.model.clone());
let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
if thread.model().is_none()
&& let Some(model) = default_model.clone()
{
thread.set_model(model);
thread.set_model(model, cx);
cx.notify();
}
thread.set_summarization_model(summarization_model.clone(), cx);
});
}
}
@ -462,10 +464,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
+ FnOnce(
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+ FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
@ -489,7 +488,18 @@ impl NativeAgentConnection {
log::trace!("Received completion event: {:?}", event);
match event {
AgentResponseEvent::Text(text) => {
ThreadEvent::UserMessage(message) => {
acp_thread.update(cx, |thread, cx| {
for content in message.content {
thread.push_user_content_block(
Some(message.id.clone()),
content.into(),
cx,
);
}
})?;
}
ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@ -501,7 +511,7 @@ impl NativeAgentConnection {
)
})?;
}
AgentResponseEvent::Thinking(text) => {
ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@ -513,7 +523,7 @@ impl NativeAgentConnection {
)
})?;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
@ -536,22 +546,26 @@ impl NativeAgentConnection {
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
AgentResponseEvent::ToolCallUpdate(update) => {
ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
AgentResponseEvent::Retry(status) => {
ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
}
ThreadEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx)
})?;
}
AgentResponseEvent::Stop(stop_reason) => {
ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.set_model(model.clone());
thread.update(cx, |thread, cx| {
thread.set_model(model.clone(), cx);
});
update_settings_file::<AgentSettings>(
@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
let language_registry = project.read(cx).languages().clone();
// Log available models for debugging
let available_count = registry.available_models(cx).count();
@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
});
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| {
let mut thread = Thread::new(
@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
action_log.clone(),
agent.templates.clone(),
default_model,
summarization_model,
cx,
);
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
},
)??;
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|_cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
action_log.clone(),
session_id.clone(),
)
})
})?;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel());
agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
Task::ready(
self.0
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
)
}
}

View file

@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let message = thread.last_message().unwrap();
@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
);
}
async fn expect_tool_call(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
ThreadEvent::ToolCall(tool_call) => return tool_call,
event => {
panic!("Unexpected event {event:?}");
}
@ -752,7 +750,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received")
.unwrap();
match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update;
}
event => {
@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false;
while let Some(event) = events.next().await {
match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => {
ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" {
echo_id = Some(tool_call.id);
}
}
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate {
id,
fields:
@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel());
thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
matches!(
last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
),
"unexpected event {last_event:?}"
);
@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
thread
.update(cx, |thread, _cx| thread.truncate(message_id))
.update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let summary_model = Arc::new(FakeLanguageModel::default());
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summary_model.clone()), cx)
});
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
summary_model.send_last_completion_stream_text_chunk("world\nG");
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
// Send another message, ensuring no title is generated this time.
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello again"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey again!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
ThreadEvent::Stop(..) => break,
_ => {}
}
}
@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
AgentResponseEvent::Retry(retry_status) => {
ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
AgentResponseEvent::Stop(..) => break,
ThreadEvent::Stop(..) => break,
_ => {}
}
}
@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn);
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(AgentResponseEvent::Retry(retry_status)) => {
Ok(ThreadEvent::Retry(retry_status)) => {
retry_events.push(retry_status);
}
Ok(AgentResponseEvent::Stop(..)) => break,
Ok(ThreadEvent::Stop(..)) => break,
Err(error) => errors.push(error),
_ => {}
}
@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
action_log,
templates,
Some(model.clone()),
None,
cx,
)
});

File diff suppressed because it is too large Load diff

View file

@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
})
})
}
fn replay(
&self,
_input: serde_json::Value,
_output: serde_json::Value,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
}

View file

@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
#[serde(default)]
diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
thread: Entity<Thread>,
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self {
Self { thread }
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self {
thread,
language_registry,
}
}
fn authorize(
@ -167,8 +173,11 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx);
let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone();
let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
});
}
let Some(request) = self.thread.update(cx, |thread, cx| {
thread
.build_completion_request(CompletionIntent::ToolResults, cx)
.ok()
}) else {
return Task::ready(Err(anyhow!("Failed to build completion request")));
};
let thread = self.thread.read(cx);
let Some(model) = thread.model().cloned() else {
return Task::ready(Err(anyhow!("No language model configured")));
};
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().cloned(), thread.action_log().clone())
})?;
let request = request?;
let model = model.context("No language model configured")?;
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
})
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
}
/// Validate that the file path is valid, meaning:
@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -527,6 +554,7 @@ mod tests {
action_log,
Templates::new(),
Some(model),
None,
cx,
)
});
@ -537,7 +565,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
})
.await;
assert_eq!(
@ -724,6 +756,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
@ -750,9 +783,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -806,7 +840,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the unformatted content
@ -850,6 +888,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -860,6 +899,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
@ -887,9 +927,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
Arc::new(EditFileTool::new(
thread.downgrade(),
language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx)
});
@ -938,10 +979,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool {
thread: thread.clone(),
})
.run(input, ToolCallEventStream::test().0, cx)
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
});
// Stream the content with trailing whitespace
@ -976,6 +1018,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@ -986,10 +1029,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@ -1111,6 +1155,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1123,10 +1168,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@ -1220,7 +1266,7 @@ mod tests {
cx,
)
.await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1233,10 +1279,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@ -1302,6 +1349,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1314,10 +1362,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@ -1386,6 +1435,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1398,10 +1448,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@ -1467,6 +1518,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1479,10 +1531,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
None,
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({

View file

@ -319,7 +319,7 @@ mod tests {
use theme::ThemeSettings;
use util::test::TempTree;
use crate::AgentResponseEvent;
use crate::ThreadEvent;
use super::*;
@ -396,7 +396,7 @@ mod tests {
});
cx.run_until_parked();
let event = stream_rx.try_next();
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap();
}

View file

@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
}
};
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
emit_update(&response, &event_stream);
Ok(WebSearchToolOutput(response))
})
}
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
emit_update(&output.0, &event_stream);
Ok(())
}
}
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
}

View file

@ -18,6 +18,7 @@ doctest = false
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
agentic-coding-protocol.workspace = true

View file

@ -1,4 +1,5 @@
// Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow};
@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.name, self.clone(), project, session_id, cx)
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
});
current_thread.replace(thread.downgrade());
thread

View file

@ -1,3 +1,4 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow;
use collections::HashMap;
@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
})?;
let session_id = response.session_id;
let thread = cx.new(|cx| {
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
action_log,
session_id.clone(),
cx,
)
})?;

View file

@ -1,6 +1,7 @@
mod mcp_server;
pub mod tools;
use action_log::ActionLog;
use collections::HashMap;
use context_server::listener::McpServerTool;
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
}
});
let thread = cx.new(|cx| {
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|_cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
)
})?;
thread_tx.send(thread.downgrade())?;

View file

@ -303,8 +303,13 @@ impl AcpThreadView {
let action_log_subscription =
cx.observe(&action_log, |_, _, cx| cx.notify());
this.list_state
.splice(0..0, thread.read(cx).entries().len());
let count = thread.read(cx).entries().len();
this.list_state.splice(0..0, count);
this.entry_view_state.update(cx, |view_state, cx| {
for ix in 0..count {
view_state.sync_entry(ix, &thread, window, cx);
}
});
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
@ -808,6 +813,7 @@ impl AcpThreadView {
self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
}
AcpThreadEvent::TitleUpdated => {}
}
cx.notify();
}
@ -2816,12 +2822,15 @@ impl AcpThreadView {
return;
};
thread.update(cx, |thread, _cx| {
thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
thread.set_completion_mode(
match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
},
cx,
);
});
}
@ -3572,8 +3581,9 @@ impl AcpThreadView {
))
.on_click({
cx.listener(move |this, _, _window, cx| {
thread.update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
thread.update(cx, |thread, cx| {
thread
.set_completion_mode(CompletionMode::Burn, cx);
});
this.resume_chat(cx);
})
@ -4156,12 +4166,13 @@ pub(crate) mod tests {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx.new(|cx| {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(
"SaboteurAgentConnection",
self,
project,
action_log,
SessionId("test".into()),
cx,
)
})))
}

View file

@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone();
let mut this = Self {
_subscriptions: [
Some(
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx)
}),
),
_subscriptions: vec![
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx)
}),
match &thread {
AgentDiffThread::Native(thread) => {
Some(cx.subscribe(thread, |this, _thread, event, cx| {
this.handle_thread_event(event, cx)
}))
}
AgentDiffThread::AcpThread(_) => None,
AgentDiffThread::Native(thread) => cx
.subscribe(thread, |this, _thread, event, cx| {
this.handle_native_thread_event(event, cx)
}),
AgentDiffThread::AcpThread(thread) => cx
.subscribe(thread, |this, _thread, event, cx| {
this.handle_acp_thread_event(event, cx)
}),
},
]
.into_iter()
.flatten()
.collect(),
],
title: SharedString::default(),
multibuffer,
editor,
@ -324,13 +321,20 @@ impl AgentDiffPane {
}
}
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event {
ThreadEvent::SummaryGenerated => self.update_title(cx),
_ => {}
}
}
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
match event {
AcpThreadEvent::TitleUpdated => self.update_title(cx),
_ => {}
}
}
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
self.editor.update(cx, |editor, cx| {
@ -1523,7 +1527,8 @@ impl AgentDiff {
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
self.update_reviewing_editors(workspace, window, cx);
}
AcpThreadEvent::EntriesRemoved(_)
AcpThreadEvent::TitleUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_) => {}
}