Lay the groundwork to support history in agent2 (#36483)

This pull request introduces title generation and history replaying. We
still need to wire up the rest of the history but this gets us very
close. I extracted a lot of this code from `agent2-history` because that
branch was starting to get long-lived and there were lots of changes
since we started.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-08-19 16:24:23 +02:00 committed by GitHub
parent c4083b9b63
commit 6c255c1973
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 929 additions and 328 deletions

3
Cargo.lock generated
View file

@ -191,6 +191,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"action_log", "action_log",
"agent",
"agent-client-protocol", "agent-client-protocol",
"agent_servers", "agent_servers",
"agent_settings", "agent_settings",
@ -208,6 +209,7 @@ dependencies = [
"env_logger 0.11.8", "env_logger 0.11.8",
"fs", "fs",
"futures 0.3.31", "futures 0.3.31",
"git",
"gpui", "gpui",
"gpui_tokio", "gpui_tokio",
"handlebars 4.5.0", "handlebars 4.5.0",
@ -256,6 +258,7 @@ name = "agent_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"acp_thread", "acp_thread",
"action_log",
"agent-client-protocol", "agent-client-protocol",
"agent_settings", "agent_settings",
"agentic-coding-protocol", "agentic-coding-protocol",

View file

@ -537,9 +537,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => { acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx)) Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
} }
acp::ToolCallContent::Diff { diff } => { acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx))) Diff::finalized(
} diff.path,
diff.old_text,
diff.new_text,
language_registry,
cx,
)
})),
} }
} }
@ -682,6 +688,7 @@ pub struct AcpThread {
#[derive(Debug)] #[derive(Debug)]
pub enum AcpThreadEvent { pub enum AcpThreadEvent {
NewEntry, NewEntry,
TitleUpdated,
EntryUpdated(usize), EntryUpdated(usize),
EntriesRemoved(Range<usize>), EntriesRemoved(Range<usize>),
ToolAuthorizationRequired, ToolAuthorizationRequired,
@ -728,11 +735,9 @@ impl AcpThread {
title: impl Into<SharedString>, title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>, connection: Rc<dyn AgentConnection>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>,
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut Context<Self>,
) -> Self { ) -> Self {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self { Self {
action_log, action_log,
shared_buffers: Default::default(), shared_buffers: Default::default(),
@ -926,6 +931,12 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry); cx.emit(AcpThreadEvent::NewEntry);
} }
pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
self.title = title;
cx.emit(AcpThreadEvent::TitleUpdated);
Ok(())
}
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) { pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status)); cx.emit(AcpThreadEvent::Retry(status));
} }
@ -1657,7 +1668,7 @@ mod tests {
use super::*; use super::*;
use anyhow::anyhow; use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select}; use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity}; use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc; use indoc::indoc;
use project::{FakeFs, Fs}; use project::{FakeFs, Fs};
use rand::Rng as _; use rand::Rng as _;
@ -2327,7 +2338,7 @@ mod tests {
self: Rc<Self>, self: Rc<Self>,
project: Entity<Project>, project: Entity<Project>,
_cwd: &Path, _cwd: &Path,
cx: &mut gpui::App, cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId( let session_id = acp::SessionId(
rand::thread_rng() rand::thread_rng()
@ -2337,8 +2348,16 @@ mod tests {
.collect::<String>() .collect::<String>()
.into(), .into(),
); );
let thread = let action_log = cx.new(|_| ActionLog::new(project.clone()));
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert(session_id, thread.downgrade()); self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread)) Task::ready(Ok(thread))
} }

View file

@ -5,11 +5,12 @@ use collections::IndexMap;
use gpui::{Entity, SharedString, Task}; use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId; use language_model::LanguageModelProviderId;
use project::Project; use project::Project;
use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName}; use ui::{App, IconName};
use uuid::Uuid; use uuid::Uuid;
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessageId(Arc<str>); pub struct UserMessageId(Arc<str>);
impl UserMessageId { impl UserMessageId {
@ -208,6 +209,7 @@ impl AgentModelList {
mod test_support { mod test_support {
use std::sync::Arc; use std::sync::Arc;
use action_log::ActionLog;
use collections::HashMap; use collections::HashMap;
use futures::{channel::oneshot, future::try_join_all}; use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity}; use gpui::{AppContext as _, WeakEntity};
@ -295,8 +297,16 @@ mod test_support {
cx: &mut gpui::App, cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into()); let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread = let action_log = cx.new(|_| ActionLog::new(project.clone()));
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx)); let thread = cx.new(|_cx| {
AcpThread::new(
"Test",
self.clone(),
project,
action_log,
session_id.clone(),
)
});
self.sessions.lock().insert( self.sessions.lock().insert(
session_id, session_id,
Session { Session {

View file

@ -1,4 +1,3 @@
use agent_client_protocol as acp;
use anyhow::Result; use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey}; use editor::{MultiBuffer, PathKey};
@ -21,17 +20,13 @@ pub enum Diff {
} }
impl Diff { impl Diff {
pub fn from_acp( pub fn finalized(
diff: acp::Diff, path: PathBuf,
old_text: Option<String>,
new_text: String,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

View file

@ -2,6 +2,7 @@ use agent::ThreadId;
use anyhow::{Context as _, Result, bail}; use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons; use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId}; use prompt_store::{PromptId, UserPromptId};
use serde::{Deserialize, Serialize};
use std::{ use std::{
fmt, fmt,
ops::Range, ops::Range,
@ -11,7 +12,7 @@ use std::{
use ui::{App, IconName, SharedString}; use ui::{App, IconName, SharedString};
use url::Url; use url::Url;
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MentionUri { pub enum MentionUri {
File { File {
abs_path: PathBuf, abs_path: PathBuf,

View file

@ -14,6 +14,7 @@ workspace = true
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
action_log.workspace = true action_log.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent_servers.workspace = true agent_servers.workspace = true
agent_settings.workspace = true agent_settings.workspace = true
@ -26,6 +27,7 @@ collections.workspace = true
context_server.workspace = true context_server.workspace = true
fs.workspace = true fs.workspace = true
futures.workspace = true futures.workspace = true
git.workspace = true
gpui.workspace = true gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] } handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true html_to_markdown.workspace = true
@ -59,6 +61,7 @@ which.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true
[dev-dependencies] [dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] } client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] } clock = { workspace = true, "features" = ["test-support"] }
@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] } fs = { workspace = true, "features" = ["test-support"] }
git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] } language = { workspace = true, "features" = ["test-support"] }

View file

@ -1,10 +1,11 @@
use crate::{ use crate::{
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, UserMessageContent, WebSearchTool, templates::Templates,
}; };
use acp_thread::AgentModelSelector; use acp_thread::AgentModelSelector;
use action_log::ActionLog;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::AgentSettings; use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@ -427,18 +428,19 @@ impl NativeAgent {
) { ) {
self.models.refresh_list(cx); self.models.refresh_list(cx);
let default_model = LanguageModelRegistry::read_global(cx) let registry = LanguageModelRegistry::read_global(cx);
.default_model() let default_model = registry.default_model().map(|m| m.model.clone());
.map(|m| m.model.clone()); let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() { for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| { session.thread.update(cx, |thread, cx| {
if thread.model().is_none() if thread.model().is_none()
&& let Some(model) = default_model.clone() && let Some(model) = default_model.clone()
{ {
thread.set_model(model); thread.set_model(model, cx);
cx.notify(); cx.notify();
} }
thread.set_summarization_model(summarization_model.clone(), cx);
}); });
} }
} }
@ -462,10 +464,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId, session_id: acp::SessionId,
cx: &mut App, cx: &mut App,
f: impl 'static f: impl 'static
+ FnOnce( + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
Entity<Thread>,
&mut App,
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent agent
@ -489,7 +488,18 @@ impl NativeAgentConnection {
log::trace!("Received completion event: {:?}", event); log::trace!("Received completion event: {:?}", event);
match event { match event {
AgentResponseEvent::Text(text) => { ThreadEvent::UserMessage(message) => {
acp_thread.update(cx, |thread, cx| {
for content in message.content {
thread.push_user_content_block(
Some(message.id.clone()),
content.into(),
cx,
);
}
})?;
}
ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -501,7 +511,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::Thinking(text) => { ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block( thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
@ -513,7 +523,7 @@ impl NativeAgentConnection {
) )
})?; })?;
} }
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call, tool_call,
options, options,
response, response,
@ -536,22 +546,26 @@ impl NativeAgentConnection {
}) })
.detach(); .detach();
} }
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx) thread.upsert_tool_call(tool_call, cx)
})??; })??;
} }
AgentResponseEvent::ToolCallUpdate(update) => { ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx) thread.update_tool_call(update, cx)
})??; })??;
} }
AgentResponseEvent::Retry(status) => { ThreadEvent::TitleUpdate(title) => {
acp_thread
.update(cx, |thread, cx| thread.update_title(title, cx))??;
}
ThreadEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| { acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx) thread.update_retry_status(status, cx)
})?; })?;
} }
AgentResponseEvent::Stop(stop_reason) => { ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason); log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason }); return Ok(acp::PromptResponse { stop_reason });
} }
@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
}; };
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
thread.set_model(model.clone()); thread.set_model(model.clone(), cx);
}); });
update_settings_file::<AgentSettings>( update_settings_file::<AgentSettings>(
@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context"); log::debug!("Starting thread creation in async context");
// Generate session ID let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread // Create Thread
let thread = agent.update( let thread = agent.update(
cx, cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> { |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings // Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx); let registry = LanguageModelRegistry::read_global(cx);
let language_registry = project.read(cx).languages().clone();
// Log available models for debugging // Log available models for debugging
let available_count = registry.available_models(cx).count(); let available_count = registry.available_models(cx).count();
@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.models .models
.model_from_id(&LanguageModels::model_id(&default_model.model)) .model_from_id(&LanguageModels::model_id(&default_model.model))
}); });
let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let mut thread = Thread::new( let mut thread = Thread::new(
@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
action_log.clone(), action_log.clone(),
agent.templates.clone(), agent.templates.clone(),
default_model, default_model,
summarization_model,
cx, cx,
); );
thread.add_tool(CopyPathTool::new(project.clone())); thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone())); thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone())); thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone())); thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity())); thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client())); thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone())); thread.add_tool(GrepTool::new(project.clone()));
@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(MovePathTool::new(project.clone())); thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool); thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone())); thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
thread.add_tool(TerminalTool::new(project.clone(), cx)); thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool); thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model. thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}, },
)??; )??;
let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|_cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
action_log.clone(),
session_id.clone(),
)
})
})?;
// Store the session // Store the session
agent.update(cx, |agent, cx| { agent.update(cx, |agent, cx| {
agent.sessions.insert( agent.sessions.insert(
@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id); log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| { self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) { if let Some(agent) = agent.sessions.get(session_id) {
agent.thread.update(cx, |thread, _cx| thread.cancel()); agent.thread.update(cx, |thread, cx| thread.cancel(cx));
} }
}); });
} }
@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor { impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> { fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id))) Task::ready(
self.0
.update(cx, |thread, cx| thread.truncate(message_id, cx)),
)
} }
} }

View file

@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false; let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message // Look for a tool use in the thread's last message
let message = thread.last_message().unwrap(); let message = thread.last_message().unwrap();
@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
); );
} }
async fn expect_tool_call( async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events let event = events
.next() .next()
.await .await
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCall(tool_call) => return tool_call, ThreadEvent::ToolCall(tool_call) => return tool_call,
event => { event => {
panic!("Unexpected event {event:?}"); panic!("Unexpected event {event:?}");
} }
@ -752,7 +750,7 @@ async fn expect_tool_call(
} }
async fn expect_tool_call_update_fields( async fn expect_tool_call_update_fields(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate { ) -> acp::ToolCallUpdate {
let event = events let event = events
.next() .next()
@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
match event { match event {
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update; return update;
} }
event => { event => {
@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
} }
async fn next_tool_call_authorization( async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>, events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization { ) -> ToolCallAuthorization {
loop { loop {
let event = events let event = events
@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
.await .await
.expect("no tool call authorization event received") .expect("no tool call authorization event received")
.unwrap(); .unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization let permission_kinds = tool_call_authorization
.options .options
.iter() .iter()
@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false; let mut echo_completed = false;
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
match event.unwrap() { match event.unwrap() {
AgentResponseEvent::ToolCall(tool_call) => { ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0)); assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" { if tool_call.title == "Echo" {
echo_id = Some(tool_call.id); echo_id = Some(tool_call.id);
} }
} }
AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate { acp::ToolCallUpdate {
id, id,
fields: fields:
@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even // Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running. // if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel()); thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await; let events = events.collect::<Vec<_>>().await;
let last_event = events.last(); let last_event = events.last();
assert!( assert!(
matches!( matches!(
last_event, last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
), ),
"unexpected event {last_event:?}" "unexpected event {last_event:?}"
); );
@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
}); });
thread thread
.update(cx, |thread, _cx| thread.truncate(message_id)) .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
}); });
} }
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let summary_model = Arc::new(FakeLanguageModel::default());
thread.update(cx, |thread, cx| {
thread.set_summarization_model(Some(summary_model.clone()), cx)
});
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
summary_model.send_last_completion_stream_text_chunk("world\nG");
summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
// Send another message, ensuring no title is generated this time.
let send = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello again"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey again!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
#[gpui::test] #[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) { async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init); cx.update(settings::init);
@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(UserMessageId::new(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut retry_events = Vec::new(); let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await { while let Some(Ok(event)) = events.next().await {
match event { match event {
AgentResponseEvent::Retry(retry_status) => { ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status); retry_events.push(retry_status);
} }
AgentResponseEvent::Stop(..) => break, ThreadEvent::Stop(..) => break,
_ => {} _ => {}
} }
} }
@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(UserMessageId::new(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut retry_events = Vec::new(); let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await { while let Some(Ok(event)) = events.next().await {
match event { match event {
AgentResponseEvent::Retry(retry_status) => { ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status); retry_events.push(retry_status);
} }
AgentResponseEvent::Stop(..) => break, ThreadEvent::Stop(..) => break,
_ => {} _ => {}
} }
} }
@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(UserMessageId::new(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut retry_events = Vec::new(); let mut retry_events = Vec::new();
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
match event { match event {
Ok(AgentResponseEvent::Retry(retry_status)) => { Ok(ThreadEvent::Retry(retry_status)) => {
retry_events.push(retry_status); retry_events.push(retry_status);
} }
Ok(AgentResponseEvent::Stop(..)) => break, Ok(ThreadEvent::Stop(..)) => break,
Err(error) => errors.push(error), Err(error) => errors.push(error),
_ => {} _ => {}
} }
@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
} }
/// Filters out the stop events for asserting against in tests /// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> { fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events result_events
.into_iter() .into_iter()
.filter_map(|event| match event.unwrap() { .filter_map(|event| match event.unwrap() {
AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None, _ => None,
}) })
.collect() .collect()
@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
action_log, action_log,
templates, templates,
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });

File diff suppressed because it is too large Load diff

View file

@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
}) })
}) })
} }
fn replay(
&self,
_input: serde_json::Value,
_output: serde_json::Value,
_event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
Ok(())
}
} }

View file

@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent; use cloud_llm_client::CompletionIntent;
use collections::HashSet; use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task}; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc; use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave}; use language::language_settings::{self, FormatOnSave};
use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent; use language_model::LanguageModelToolResultContent;
use paths; use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget}; use project::lsp_store::{FormatTrigger, LspFormatTarget};
@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput { pub struct EditFileToolOutput {
#[serde(alias = "original_path")]
input_path: PathBuf, input_path: PathBuf,
project_path: PathBuf,
new_text: String, new_text: String,
old_text: Arc<String>, old_text: Arc<String>,
#[serde(default)]
diff: String, diff: String,
#[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput, edit_agent_output: EditAgentOutput,
} }
@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
} }
pub struct EditFileTool { pub struct EditFileTool {
thread: Entity<Thread>, thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
} }
impl EditFileTool { impl EditFileTool {
pub fn new(thread: Entity<Thread>) -> Self { pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
Self { thread } Self {
thread,
language_registry,
}
} }
fn authorize( fn authorize(
@ -167,8 +173,11 @@ impl EditFileTool {
// Check if path is inside the global config directory // Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize // First check if it's already inside project - if not, try to canonicalize
let thread = self.thread.read(cx); let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
let project_path = thread.project().read(cx).find_project_path(&input.path, cx); thread.project().read(cx).find_project_path(&input.path, cx)
}) else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
// If the path is inside the project, and it's not one of the above edge cases, // If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary. // then no confirmation is necessary. Otherwise, confirmation is necessary.
@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<Self::Output>> { ) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone(); let Ok(project) = self
.thread
.read_with(cx, |thread, _cx| thread.project().clone())
else {
return Task::ready(Err(anyhow!("thread was dropped")));
};
let project_path = match resolve_path(&input, project.clone(), cx) { let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path, Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))), Err(err) => return Task::ready(Err(anyhow!(err))),
@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
}); });
} }
let Some(request) = self.thread.update(cx, |thread, cx| {
thread
.build_completion_request(CompletionIntent::ToolResults, cx)
.ok()
}) else {
return Task::ready(Err(anyhow!("Failed to build completion request")));
};
let thread = self.thread.read(cx);
let Some(model) = thread.model().cloned() else {
return Task::ready(Err(anyhow!("No language model configured")));
};
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx); let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| { cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?; authorize.await?;
let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
(request, thread.model().cloned(), thread.action_log().clone())
})?;
let request = request?;
let model = model.context("No language model configured")?;
let edit_format = EditFormat::from_model(model.clone())?; let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new( let edit_agent = EditAgent::new(
model, model,
@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput { Ok(EditFileToolOutput {
input_path: input.path, input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(), new_text: new_text.clone(),
old_text, old_text,
diff: unified_diff, diff: unified_diff,
@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
}) })
}) })
} }
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
event_stream.update_diff(cx.new(|cx| {
Diff::finalized(
output.input_path,
Some(output.old_text.to_string()),
output.new_text,
self.language_registry.clone(),
cx,
)
}));
Ok(())
}
} }
/// Validate that the file path is valid, meaning: /// Validate that the file path is valid, meaning:
@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -527,6 +554,7 @@ mod tests {
action_log, action_log,
Templates::new(), Templates::new(),
Some(model), Some(model),
None,
cx, cx,
) )
}); });
@ -537,7 +565,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(), path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit, mode: EditFileMode::Edit,
}; };
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
}) })
.await; .await;
assert_eq!( assert_eq!(
@ -724,6 +756,7 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
@ -750,9 +783,10 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(
thread: thread.clone(), thread.downgrade(),
}) language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx) .run(input, ToolCallEventStream::test().0, cx)
}); });
@ -806,7 +840,11 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
}); });
// Stream the unformatted content // Stream the unformatted content
@ -850,6 +888,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
@ -860,6 +899,7 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
@ -887,9 +927,10 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(
thread: thread.clone(), thread.downgrade(),
}) language_registry.clone(),
))
.run(input, ToolCallEventStream::test().0, cx) .run(input, ToolCallEventStream::test().0, cx)
}); });
@ -938,10 +979,11 @@ mod tests {
path: "root/src/main.rs".into(), path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite, mode: EditFileMode::Overwrite,
}; };
Arc::new(EditFileTool { Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
thread: thread.clone(), input,
}) ToolCallEventStream::test().0,
.run(input, ToolCallEventStream::test().0, cx) cx,
)
}); });
// Stream the content with trailing whitespace // Stream the content with trailing whitespace
@ -976,6 +1018,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default()); let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
@ -986,10 +1029,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await; fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation // Test 1: Path with .zed component should require confirmation
@ -1111,6 +1155,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await; fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
@ -1123,10 +1168,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project // Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![ let test_cases = vec![
@ -1220,7 +1266,7 @@ mod tests {
cx, cx,
) )
.await; .await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1233,10 +1279,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees // Test files in different worktrees
let test_cases = vec![ let test_cases = vec![
@ -1302,6 +1349,7 @@ mod tests {
) )
.await; .await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1314,10 +1362,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases // Test edge cases
let test_cases = vec![ let test_cases = vec![
@ -1386,6 +1435,7 @@ mod tests {
) )
.await; .await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1398,10 +1448,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values // Test different EditFileMode values
let modes = vec![ let modes = vec![
@ -1467,6 +1518,7 @@ mod tests {
init_test(cx); init_test(cx);
let fs = project::FakeFs::new(cx.executor()); let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry = let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@ -1479,10 +1531,11 @@ mod tests {
action_log.clone(), action_log.clone(),
Templates::new(), Templates::new(),
Some(model.clone()), Some(model.clone()),
None,
cx, cx,
) )
}); });
let tool = Arc::new(EditFileTool { thread }); let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!( assert_eq!(
tool.initial_title(Err(json!({ tool.initial_title(Err(json!({

View file

@ -319,7 +319,7 @@ mod tests {
use theme::ThemeSettings; use theme::ThemeSettings;
use util::test::TempTree; use util::test::TempTree;
use crate::AgentResponseEvent; use crate::ThreadEvent;
use super::*; use super::*;
@ -396,7 +396,7 @@ mod tests {
}); });
cx.run_until_parked(); cx.run_until_parked();
let event = stream_rx.try_next(); let event = stream_rx.try_next();
if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap(); auth.response.send(auth.options[0].id.clone()).unwrap();
} }

View file

@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
} }
}; };
let result_text = if response.results.len() == 1 { emit_update(&response, &event_stream);
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
Ok(WebSearchToolOutput(response)) Ok(WebSearchToolOutput(response))
}) })
} }
fn replay(
&self,
_input: Self::Input,
output: Self::Output,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Result<()> {
emit_update(&output.0, &event_stream);
Ok(())
}
}
fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
let result_text = if response.results.len() == 1 {
"1 result".to_string()
} else {
format!("{} results", response.results.len())
};
event_stream.update_fields(acp::ToolCallUpdateFields {
title: Some(format!("Searched the web: {result_text}")),
content: Some(
response
.results
.iter()
.map(|result| acp::ToolCallContent::Content {
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
name: result.title.clone(),
uri: result.url.clone(),
title: Some(result.title.clone()),
description: Some(result.text.clone()),
mime_type: None,
annotations: None,
size: None,
}),
})
.collect(),
),
..Default::default()
});
} }

View file

@ -18,6 +18,7 @@ doctest = false
[dependencies] [dependencies]
acp_thread.workspace = true acp_thread.workspace = true
action_log.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent_settings.workspace = true agent_settings.workspace = true
agentic-coding-protocol.workspace = true agentic-coding-protocol.workspace = true

View file

@ -1,4 +1,5 @@
// Translates old acp agents into the new schema // Translates old acp agents into the new schema
use action_log::ActionLog;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _}; use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
cx.update(|cx| { cx.update(|cx| {
let thread = cx.new(|cx| { let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into()); let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new(self.name, self.clone(), project, session_id, cx) let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(self.name, self.clone(), project, action_log, session_id)
}); });
current_thread.replace(thread.downgrade()); current_thread.replace(thread.downgrade());
thread thread

View file

@ -1,3 +1,4 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _}; use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow; use anyhow::anyhow;
use collections::HashMap; use collections::HashMap;
@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
})?; })?;
let session_id = response.session_id; let session_id = response.session_id;
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
let thread = cx.new(|cx| { let thread = cx.new(|_cx| {
AcpThread::new( AcpThread::new(
self.server_name, self.server_name,
self.clone(), self.clone(),
project, project,
action_log,
session_id.clone(), session_id.clone(),
cx,
) )
})?; })?;

View file

@ -1,6 +1,7 @@
mod mcp_server; mod mcp_server;
pub mod tools; pub mod tools;
use action_log::ActionLog;
use collections::HashMap; use collections::HashMap;
use context_server::listener::McpServerTool; use context_server::listener::McpServerTool;
use language_models::provider::anthropic::AnthropicLanguageModelProvider; use language_models::provider::anthropic::AnthropicLanguageModelProvider;
@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
} }
}); });
let thread = cx.new(|cx| { let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx) let thread = cx.new(|_cx| {
AcpThread::new(
"Claude Code",
self.clone(),
project,
action_log,
session_id.clone(),
)
})?; })?;
thread_tx.send(thread.downgrade())?; thread_tx.send(thread.downgrade())?;

View file

@ -303,8 +303,13 @@ impl AcpThreadView {
let action_log_subscription = let action_log_subscription =
cx.observe(&action_log, |_, _, cx| cx.notify()); cx.observe(&action_log, |_, _, cx| cx.notify());
this.list_state let count = thread.read(cx).entries().len();
.splice(0..0, thread.read(cx).entries().len()); this.list_state.splice(0..0, count);
this.entry_view_state.update(cx, |view_state, cx| {
for ix in 0..count {
view_state.sync_entry(ix, &thread, window, cx);
}
});
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx); AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
@ -808,6 +813,7 @@ impl AcpThreadView {
self.thread_retry_status.take(); self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status }; self.thread_state = ThreadState::ServerExited { status: *status };
} }
AcpThreadEvent::TitleUpdated => {}
} }
cx.notify(); cx.notify();
} }
@ -2816,12 +2822,15 @@ impl AcpThreadView {
return; return;
}; };
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode(); let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode { thread.set_completion_mode(
CompletionMode::Burn => CompletionMode::Normal, match current_mode {
CompletionMode::Normal => CompletionMode::Burn, CompletionMode::Burn => CompletionMode::Normal,
}); CompletionMode::Normal => CompletionMode::Burn,
},
cx,
);
}); });
} }
@ -3572,8 +3581,9 @@ impl AcpThreadView {
)) ))
.on_click({ .on_click({
cx.listener(move |this, _, _window, cx| { cx.listener(move |this, _, _window, cx| {
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, cx| {
thread.set_completion_mode(CompletionMode::Burn); thread
.set_completion_mode(CompletionMode::Burn, cx);
}); });
this.resume_chat(cx); this.resume_chat(cx);
}) })
@ -4156,12 +4166,13 @@ pub(crate) mod tests {
cx: &mut gpui::App, cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> { ) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx.new(|cx| { Task::ready(Ok(cx.new(|cx| {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new( AcpThread::new(
"SaboteurAgentConnection", "SaboteurAgentConnection",
self, self,
project, project,
action_log,
SessionId("test".into()), SessionId("test".into()),
cx,
) )
}))) })))
} }

View file

@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone(); let action_log = thread.action_log(cx).clone();
let mut this = Self { let mut this = Self {
_subscriptions: [ _subscriptions: vec![
Some( cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
cx.observe_in(&action_log, window, |this, _action_log, window, cx| { this.update_excerpts(window, cx)
this.update_excerpts(window, cx) }),
}),
),
match &thread { match &thread {
AgentDiffThread::Native(thread) => { AgentDiffThread::Native(thread) => cx
Some(cx.subscribe(thread, |this, _thread, event, cx| { .subscribe(thread, |this, _thread, event, cx| {
this.handle_thread_event(event, cx) this.handle_native_thread_event(event, cx)
})) }),
} AgentDiffThread::AcpThread(thread) => cx
AgentDiffThread::AcpThread(_) => None, .subscribe(thread, |this, _thread, event, cx| {
this.handle_acp_thread_event(event, cx)
}),
}, },
] ],
.into_iter()
.flatten()
.collect(),
title: SharedString::default(), title: SharedString::default(),
multibuffer, multibuffer,
editor, editor,
@ -324,13 +321,20 @@ impl AgentDiffPane {
} }
} }
fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) { fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event { match event {
ThreadEvent::SummaryGenerated => self.update_title(cx), ThreadEvent::SummaryGenerated => self.update_title(cx),
_ => {} _ => {}
} }
} }
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
match event {
AcpThreadEvent::TitleUpdated => self.update_title(cx),
_ => {}
}
}
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) { pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) { if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
self.editor.update(cx, |editor, cx| { self.editor.update(cx, |editor, cx| {
@ -1523,7 +1527,8 @@ impl AgentDiff {
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => { AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
self.update_reviewing_editors(workspace, window, cx); self.update_reviewing_editors(workspace, window, cx);
} }
AcpThreadEvent::EntriesRemoved(_) AcpThreadEvent::TitleUpdated
| AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired | AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_) => {} | AcpThreadEvent::Retry(_) => {}
} }