agent2: Port edit_file tool (#35844)

TODO:
- [x] Authorization
- [x] Restore tests

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Agus Zubiaga 2025-08-08 09:43:53 -03:00 committed by GitHub
parent d705585a2e
commit 2526dcb5a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 2075 additions and 414 deletions

4
Cargo.lock generated
View file

@ -158,8 +158,10 @@ dependencies = [
"acp_thread", "acp_thread",
"agent-client-protocol", "agent-client-protocol",
"agent_servers", "agent_servers",
"agent_settings",
"anyhow", "anyhow",
"assistant_tool", "assistant_tool",
"assistant_tools",
"client", "client",
"clock", "clock",
"cloud_llm_client", "cloud_llm_client",
@ -177,6 +179,8 @@ dependencies = [
"language_model", "language_model",
"language_models", "language_models",
"log", "log",
"lsp",
"paths",
"pretty_assertions", "pretty_assertions",
"project", "project",
"prompt_store", "prompt_store",

View file

@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# #
agentic-coding-protocol = "0.0.10" agentic-coding-protocol = "0.0.10"
agent-client-protocol = { version = "0.0.23" } agent-client-protocol = "0.0.23"
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"

View file

@ -1,18 +1,17 @@
mod connection; mod connection;
mod diff;
pub use connection::*; pub use connection::*;
pub use diff::*;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{Context as _, Result}; use anyhow::{Context as _, Result};
use assistant_tool::ActionLog; use assistant_tool::ActionLog;
use buffer_diff::BufferDiff; use editor::Bias;
use editor::{Bias, MultiBuffer, PathKey};
use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use itertools::Itertools; use itertools::Itertools;
use language::{ use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
text_diff,
};
use markdown::Markdown; use markdown::Markdown;
use project::{AgentLocation, Project}; use project::{AgentLocation, Project};
use std::collections::HashMap; use std::collections::HashMap;
@ -140,7 +139,7 @@ impl AgentThreadEntry {
} }
} }
pub fn diffs(&self) -> impl Iterator<Item = &Diff> { pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self { if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs()) itertools::Either::Left(call.diffs())
} else { } else {
@ -249,7 +248,7 @@ impl ToolCall {
} }
} }
pub fn diffs(&self) -> impl Iterator<Item = &Diff> { pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
self.content.iter().filter_map(|content| match content { self.content.iter().filter_map(|content| match content {
ToolCallContent::ContentBlock { .. } => None, ToolCallContent::ContentBlock { .. } => None,
ToolCallContent::Diff { diff } => Some(diff), ToolCallContent::Diff { diff } => Some(diff),
@ -389,7 +388,7 @@ impl ContentBlock {
#[derive(Debug)] #[derive(Debug)]
pub enum ToolCallContent { pub enum ToolCallContent {
ContentBlock { content: ContentBlock }, ContentBlock { content: ContentBlock },
Diff { diff: Diff }, Diff { diff: Entity<Diff> },
} }
impl ToolCallContent { impl ToolCallContent {
@ -403,7 +402,7 @@ impl ToolCallContent {
content: ContentBlock::new(content, &language_registry, cx), content: ContentBlock::new(content, &language_registry, cx),
}, },
acp::ToolCallContent::Diff { diff } => Self::Diff { acp::ToolCallContent::Diff { diff } => Self::Diff {
diff: Diff::from_acp(diff, language_registry, cx), diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
}, },
} }
} }
@ -411,108 +410,11 @@ impl ToolCallContent {
pub fn to_markdown(&self, cx: &App) -> String { pub fn to_markdown(&self, cx: &App) -> String {
match self { match self {
Self::ContentBlock { content } => content.to_markdown(cx).to_string(), Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
Self::Diff { diff } => diff.to_markdown(cx), Self::Diff { diff } => diff.read(cx).to_markdown(cx),
} }
} }
} }
#[derive(Debug)]
pub struct Diff {
pub multibuffer: Entity<MultiBuffer>,
pub path: PathBuf,
_task: Task<Result<()>>,
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
async move |cx| {
let language = language_registry
.language_for_file_path(&path)
.await
.log_err();
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
buffer.set_language(language, cx);
buffer.snapshot()
})?;
buffer_diff
.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry),
new_buffer_snapshot,
cx,
)
})?
.await?;
multibuffer
.update(cx, |multibuffer, cx| {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>()
};
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&new_buffer, cx),
new_buffer.clone(),
hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
})
.log_err();
anyhow::Ok(())
}
});
Self {
multibuffer,
path,
_task: task,
}
}
fn to_markdown(&self, cx: &App) -> String {
let buffer_text = self
.multibuffer
.read(cx)
.all_buffers()
.iter()
.map(|buffer| buffer.read(cx).text())
.join("\n");
format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
}
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Plan { pub struct Plan {
pub entries: Vec<PlanEntry>, pub entries: Vec<PlanEntry>,
@ -823,6 +725,21 @@ impl AcpThread {
Ok(()) Ok(())
} }
pub fn set_tool_call_diff(
&mut self,
tool_call_id: &acp::ToolCallId,
diff: Entity<Diff>,
cx: &mut Context<Self>,
) -> Result<()> {
let (ix, current_call) = self
.tool_call_mut(tool_call_id)
.context("Tool call not found")?;
current_call.content.clear();
current_call.content.push(ToolCallContent::Diff { diff });
cx.emit(AcpThreadEvent::EntryUpdated(ix));
Ok(())
}
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one. /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) { pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
let status = ToolCallStatus::Allowed { let status = ToolCallStatus::Allowed {

View file

@ -0,0 +1,388 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task};
use itertools::Itertools;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer,
};
use std::{
cmp::Reverse,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use util::ResultExt;
/// A diff produced by a tool call, rendered in the agent panel.
///
/// A diff starts out as [`Diff::Pending`] while the edit is still streaming
/// into the buffer, and is converted to [`Diff::Finalized`] (a frozen
/// snapshot) via [`Diff::finalize`] once the tool call completes.
pub enum Diff {
    Pending(PendingDiff),
    Finalized(FinalizedDiff),
}
impl Diff {
    /// Builds an already-finalized diff from an ACP `Diff` payload
    /// (path + optional old text + new text).
    ///
    /// The heavy work — language detection, computing the diff base, and
    /// populating the multibuffer with one excerpt per hunk — happens in a
    /// background task stored in the returned value, so the multibuffer may
    /// be briefly empty after this returns.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut Context<Self>,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        // Read-only: this diff is a rendering of an edit, not an editor.
        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old_text means the file is new; diff against empty.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |_, cx| {
                // Best-effort syntax highlighting based on the file path.
                let language = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err();

                new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
                let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
                    buffer.set_language(language, cx);
                    buffer.snapshot()
                })?;

                // Diff the new text against the old text.
                buffer_diff
                    .update(cx, |diff, cx| {
                        diff.set_base_text(
                            old_buffer_snapshot,
                            Some(language_registry),
                            new_buffer_snapshot,
                            cx,
                        )
                    })?
                    .await?;

                // Show one excerpt per hunk, with the standard context lines.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff, cx);
                    })
                    .log_err();

                anyhow::Ok(())
            }
        });

        // ACP diffs arrive complete, so there is no pending phase.
        Self::Finalized(FinalizedDiff {
            multibuffer,
            path,
            _update_diff: task,
        })
    }

    /// Creates a pending diff that tracks `buffer` as it is edited.
    ///
    /// The buffer's current text is captured as the diff base; an observer
    /// recomputes the diff on every subsequent buffer change until
    /// [`Diff::finalize`] is called.
    pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
        let buffer_snapshot = buffer.read(cx).snapshot();
        // The text as of now becomes the base everything else diffs against.
        let base_text = buffer_snapshot.text();
        let language_registry = buffer.read(cx).language_registry();
        let text_snapshot = buffer.read(cx).text_snapshot();
        let buffer_diff = cx.new(|cx| {
            let mut diff = BufferDiff::new(&text_snapshot, cx);
            // Base == current text, so the diff starts out empty; errors from
            // the initial set_base_text are deliberately ignored.
            let _ = diff.set_base_text(
                buffer_snapshot.clone(),
                language_registry,
                text_snapshot,
                cx,
            );
            diff
        });

        let multibuffer = cx.new(|cx| {
            let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
            multibuffer.add_diff(buffer_diff.clone(), cx);
            multibuffer
        });

        Self::Pending(PendingDiff {
            multibuffer,
            base_text: Arc::new(base_text),
            // Recompute the diff whenever the tracked buffer changes, but
            // only while we are still in the pending state.
            _subscription: cx.observe(&buffer, |this, _, cx| {
                if let Diff::Pending(diff) = this {
                    diff.update(cx);
                }
            }),
            buffer,
            diff: buffer_diff,
            revealed_ranges: Vec::new(),
            update_diff: Task::ready(Ok(())),
        })
    }

    /// Forces `range` to stay visible in a pending diff, even if it contains
    /// no hunks. No-op on a finalized diff.
    pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
        if let Self::Pending(diff) = self {
            diff.reveal_range(range, cx);
        }
    }

    /// Converts a pending diff into a finalized one, snapshotting the buffer
    /// so later edits no longer affect what is displayed. Idempotent.
    pub fn finalize(&mut self, cx: &mut Context<Self>) {
        if let Self::Pending(diff) = self {
            *self = Self::Finalized(diff.finalize(cx));
        }
    }

    /// The multibuffer that renders this diff, regardless of state.
    pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
        match self {
            Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer,
            Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer,
        }
    }

    /// Renders the diff as a markdown fenced block, e.g. for thread
    /// serialization: `Diff: <path>\n```\n<text>\n````.
    ///
    /// Note: the body is the concatenated *current* text of all buffers in
    /// the multibuffer, not a unified old/new diff.
    pub fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer()
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        let path = match self {
            Diff::Pending(PendingDiff { buffer, .. }) => {
                buffer.read(cx).file().map(|file| file.path().as_ref())
            }
            Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()),
        };
        format!(
            "Diff: {}\n```\n{}\n```\n",
            path.unwrap_or(Path::new("untitled")).display(),
            buffer_text
        )
    }
}
/// A live diff that keeps tracking its buffer: every buffer change triggers a
/// recomputation against the base text captured at creation time.
pub struct PendingDiff {
    multibuffer: Entity<MultiBuffer>,
    // The buffer's text at the moment the diff was created; all subsequent
    // diffs are computed against this.
    base_text: Arc<String>,
    // The live buffer being edited by the tool.
    buffer: Entity<Buffer>,
    diff: Entity<BufferDiff>,
    // Ranges explicitly requested to stay visible even without hunks.
    revealed_ranges: Vec<Range<Anchor>>,
    // Keeps the buffer observation alive; dropping it stops diff updates.
    _subscription: Subscription,
    // The in-flight recomputation; replaced on every buffer change.
    update_diff: Task<Result<()>>,
}
impl PendingDiff {
    /// Recomputes the diff of the buffer's current text against the captured
    /// base text, then refreshes the visible excerpt ranges.
    ///
    /// Called from the buffer observer on every change; each call replaces
    /// the previous in-flight task, so only the latest update runs to
    /// completion.
    pub fn update(&mut self, cx: &mut Context<Diff>) {
        let buffer = self.buffer.clone();
        let buffer_diff = self.diff.clone();
        let base_text = self.base_text.clone();
        self.update_diff = cx.spawn(async move |diff, cx| {
            let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
            let diff_snapshot = BufferDiff::update_diff(
                buffer_diff.clone(),
                text_snapshot.clone(),
                Some(base_text),
                false,
                false,
                None,
                None,
                cx,
            )
            .await?;
            buffer_diff.update(cx, |diff, cx| {
                diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
            })?;
            diff.update(cx, |diff, cx| {
                // The diff may have been finalized while this update was in
                // flight; in that case the excerpts are already frozen.
                if let Diff::Pending(diff) = diff {
                    diff.update_visible_ranges(cx);
                }
            })
        });
    }

    /// Forces `range` to remain visible even if it contains no hunks, and
    /// refreshes the excerpts immediately.
    pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Diff>) {
        self.revealed_ranges.push(range);
        self.update_visible_ranges(cx);
    }

    /// Freezes this diff into a [`FinalizedDiff`].
    ///
    /// The live buffer is replaced by a read-only snapshot copy so later
    /// edits no longer change what is displayed; the excerpt ranges computed
    /// here are re-applied asynchronously once the snapshot's diff is built.
    fn finalize(&self, cx: &mut Context<Diff>) -> FinalizedDiff {
        let ranges = self.excerpt_ranges(cx);
        let base_text = self.base_text.clone();
        let language_registry = self.buffer.read(cx).language_registry().clone();
        let path = self
            .buffer
            .read(cx)
            .file()
            .map(|file| file.path().as_ref())
            .unwrap_or(Path::new("untitled"))
            .into();

        // Replace the buffer in the multibuffer with the snapshot
        let buffer = cx.new(|cx| {
            let language = self.buffer.read(cx).language().cloned();
            let buffer = TextBuffer::new_normalized(
                0,
                cx.entity_id().as_non_zero_u64().into(),
                self.buffer.read(cx).line_ending(),
                self.buffer.read(cx).as_rope().clone(),
            );
            let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
            buffer.set_language(language, cx);
            buffer
        });

        let buffer_diff = cx.spawn({
            let buffer = buffer.clone();
            let language_registry = language_registry.clone();
            async move |_this, cx| {
                build_buffer_diff(base_text, &buffer, language_registry, cx).await
            }
        });

        let update_diff = cx.spawn(async move |this, cx| {
            let buffer_diff = buffer_diff.await?;
            this.update(cx, |this, cx| {
                this.multibuffer().update(cx, |multibuffer, cx| {
                    let path_key = PathKey::for_buffer(&buffer, cx);
                    multibuffer.clear(cx);
                    multibuffer.set_excerpts_for_path(
                        path_key,
                        buffer,
                        ranges,
                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
                        cx,
                    );
                    // `buffer_diff` is moved here; it was previously cloned
                    // unnecessarily even though it is never used again.
                    multibuffer.add_diff(buffer_diff, cx);
                });
                cx.notify();
            })
        });

        FinalizedDiff {
            path,
            multibuffer: self.multibuffer.clone(),
            _update_diff: update_diff,
        }
    }

    /// Re-applies the current excerpt ranges (hunks + revealed ranges) to the
    /// multibuffer and notifies observers.
    fn update_visible_ranges(&mut self, cx: &mut Context<Diff>) {
        let ranges = self.excerpt_ranges(cx);
        self.multibuffer.update(cx, |multibuffer, cx| {
            multibuffer.set_excerpts_for_path(
                PathKey::for_buffer(&self.buffer, cx),
                self.buffer.clone(),
                ranges,
                editor::DEFAULT_MULTIBUFFER_CONTEXT,
                cx,
            );
            // NOTE: a previous version computed the row count of the
            // multibuffer here (`snapshot.offset_to_point(len).row + 1`) and
            // returned it from this closure, but the value was discarded by
            // the caller — that dead computation has been removed.
        });
        cx.notify();
    }

    /// Computes the point ranges to display: every diff hunk plus every
    /// explicitly revealed range, sorted and merged so overlapping or
    /// touching ranges become a single excerpt.
    fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
        let buffer = self.buffer.read(cx);
        let diff = self.diff.read(cx);
        let mut ranges = diff
            .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
            .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
            .collect::<Vec<_>>();
        ranges.extend(
            self.revealed_ranges
                .iter()
                .map(|range| range.to_point(&buffer)),
        );
        // Sort by start ascending, end descending, so a range that contains
        // another always comes first and the merge loop below can consume
        // the contained one.
        ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

        // Merge adjacent ranges
        let mut ranges = ranges.into_iter().peekable();
        let mut merged_ranges = Vec::new();
        while let Some(mut range) = ranges.next() {
            while let Some(next_range) = ranges.peek() {
                if range.end >= next_range.start {
                    range.end = range.end.max(next_range.end);
                    ranges.next();
                } else {
                    break;
                }
            }

            merged_ranges.push(range);
        }
        merged_ranges
    }
}
/// A diff whose contents no longer change: the tracked buffer has been
/// replaced by a snapshot and the excerpts are fixed.
pub struct FinalizedDiff {
    // Path shown in the markdown rendering; "untitled" if the buffer had no file.
    path: PathBuf,
    multibuffer: Entity<MultiBuffer>,
    // Kept alive so the asynchronous excerpt/diff rebuild started during
    // finalization runs to completion.
    _update_diff: Task<Result<()>>,
}
/// Builds a [`BufferDiff`] of `buffer`'s current text against `old_text`.
///
/// The old text is converted to a rope on a background task, a base-buffer
/// snapshot is built from it (reusing the buffer's language for consistent
/// highlighting), and the resulting diff snapshot is installed into a fresh
/// `BufferDiff` entity.
async fn build_buffer_diff(
    old_text: Arc<String>,
    buffer: &Entity<Buffer>,
    language_registry: Option<Arc<LanguageRegistry>>,
    cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;

    // Rope construction can be expensive for large files; do it off the
    // main thread.
    let old_text_rope = cx
        .background_spawn({
            let old_text = old_text.clone();
            async move { Rope::from(old_text.as_str()) }
        })
        .await;
    let base_buffer = cx
        .update(|cx| {
            Buffer::build_snapshot(
                old_text_rope,
                buffer.language().cloned(),
                language_registry,
                cx,
            )
        })?
        .await;

    let diff_snapshot = cx
        .update(|cx| {
            BufferDiffSnapshot::new_with_base_buffer(
                buffer.text.clone(),
                Some(old_text),
                base_buffer,
                cx,
            )
        })?
        .await;

    // Both diffs share the same snapshot; the secondary diff is attached to
    // the primary — presumably mirroring the staged/unstaged split used by
    // git diffs elsewhere in the codebase (NOTE(review): confirm intent).
    let secondary_diff = cx.new(|cx| {
        let mut diff = BufferDiff::new(&buffer, cx);
        diff.set_snapshot(diff_snapshot.clone(), &buffer, cx);
        diff
    })?;
    cx.new(|cx| {
        let mut diff = BufferDiff::new(&buffer.text, cx);
        diff.set_snapshot(diff_snapshot, &buffer, cx);
        diff.set_secondary_diff(secondary_diff);
        diff
    })
}

View file

@ -15,8 +15,10 @@ workspace = true
acp_thread.workspace = true acp_thread.workspace = true
agent-client-protocol.workspace = true agent-client-protocol.workspace = true
agent_servers.workspace = true agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true assistant_tool.workspace = true
assistant_tools.workspace = true
cloud_llm_client.workspace = true cloud_llm_client.workspace = true
collections.workspace = true collections.workspace = true
fs.workspace = true fs.workspace = true
@ -29,6 +31,7 @@ language.workspace = true
language_model.workspace = true language_model.workspace = true
language_models.workspace = true language_models.workspace = true
log.workspace = true log.workspace = true
paths.workspace = true
project.workspace = true project.workspace = true
prompt_store.workspace = true prompt_store.workspace = true
rust-embed.workspace = true rust-embed.workspace = true
@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] } language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] }
lsp = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] } project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] } settings = { workspace = true, "features" = ["test-support"] }

View file

@ -1,5 +1,5 @@
use crate::{templates::Templates, AgentResponseEvent, Thread}; use crate::{templates::Templates, AgentResponseEvent, Thread};
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use acp_thread::ModelSelector; use acp_thread::ModelSelector;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
anyhow!("No default model configured. Please configure a default model in settings.") anyhow!("No default model configured. Please configure a default model in settings.")
})?; })?;
let thread = cx.new(|_| { let thread = cx.new(|cx| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(ThinkingTool); thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log)); thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(cx.entity()));
thread thread
}); });
@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) )
})??; })??;
} }
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
acp_thread.update(cx, |thread, cx| {
thread.set_tool_call_diff(
&tool_call_diff.tool_call_id,
tool_call_diff.diff,
cx,
)
})??;
}
AgentResponseEvent::Stop(stop_reason) => { AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason); log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason }); return Ok(acp::PromptResponse { stop_reason });

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*; pub use agent::*;
pub use native_agent_server::NativeAgentServer; pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*; pub use thread::*;
pub use tools::*; pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*; use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection; use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::Result; use anyhow::Result;
@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission.name().into(), tool_name: ToolRequiringPermission.name().into(),
is_error: false, is_error: false,
content: "Allowed".into(), content: "Allowed".into(),
output: None output: Some("Allowed".into())
}), }),
MessageContent::ToolResult(LanguageModelToolResult { MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),

View file

@ -14,6 +14,7 @@ pub struct EchoTool;
impl AgentTool for EchoTool { impl AgentTool for EchoTool {
type Input = EchoToolInput; type Input = EchoToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"echo".into() "echo".into()
@ -48,6 +49,7 @@ pub struct DelayTool;
impl AgentTool for DelayTool { impl AgentTool for DelayTool {
type Input = DelayToolInput; type Input = DelayToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"delay".into() "delay".into()
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission { impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput; type Input = ToolRequiringPermissionInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"tool_requiring_permission".into() "tool_requiring_permission".into()
@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: Self::Input, _input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> ) -> Task<Result<String>> {
where let auth_check = event_stream.authorize("Authorize?".into());
Self: Sized,
{
let auth_check = self.authorize(input, event_stream);
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
auth_check.await?; auth_check.await?;
Ok("Allowed".to_string()) Ok("Allowed".to_string())
@ -121,6 +121,7 @@ pub struct InfiniteTool;
impl AgentTool for InfiniteTool { impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput; type Input = InfiniteToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"infinite".into() "infinite".into()
@ -171,19 +172,20 @@ pub struct WordListTool;
impl AgentTool for WordListTool { impl AgentTool for WordListTool {
type Input = WordListInput; type Input = WordListInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"word_list".into() "word_list".into()
} }
fn initial_title(&self, _input: Self::Input) -> SharedString {
"List of random words".into()
}
fn kind(&self) -> acp::ToolKind { fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other acp::ToolKind::Other
} }
fn initial_title(&self, _input: Self::Input) -> SharedString {
"List of random words".into()
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
_input: Self::Input, _input: Self::Input,

View file

@ -1,4 +1,5 @@
use crate::templates::{SystemPromptTemplate, Template, Templates}; use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{adapt_schema_to_format, ActionLog}; use assistant_tool::{adapt_schema_to_format, ActionLog};
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall), ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate), ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization), ToolCallAuthorization(ToolCallAuthorization),
ToolCallDiff(ToolCallDiff),
Stop(acp::StopReason), Stop(acp::StopReason),
} }
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>, pub response: oneshot::Sender<acp::PermissionOptionId>,
} }
#[derive(Debug)]
pub struct ToolCallDiff {
pub tool_call_id: acp::ToolCallId,
pub diff: Entity<acp_thread::Diff>,
}
pub struct Thread { pub struct Thread {
messages: Vec<AgentMessage>, messages: Vec<AgentMessage>,
completion_mode: CompletionMode, completion_mode: CompletionMode,
@ -125,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>, templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>, pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
} }
impl Thread { impl Thread {
pub fn new( pub fn new(
_project: Entity<Project>, project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>, project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
templates: Arc<Templates>, templates: Arc<Templates>,
@ -145,10 +154,19 @@ impl Thread {
project_context, project_context,
templates, templates,
selected_model: default_model, selected_model: default_model,
project,
action_log, action_log,
} }
} }
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) { pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode; self.completion_mode = mode;
} }
@ -315,10 +333,6 @@ impl Thread {
events_rx events_rx
} }
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage { pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message"); log::debug!("Building system message");
let prompt = SystemPromptTemplate { let prompt = SystemPromptTemplate {
@ -490,15 +504,33 @@ impl Thread {
})); }));
}; };
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx); let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move { Some(cx.foreground_executor().spawn(async move {
match tool_result.await { let tool_result = tool_result.await.and_then(|output| {
Ok(tool_output) => LanguageModelToolResult { if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
if !supports_images {
return Err(anyhow!(
"Attempted to read an image, but this model doesn't support it.",
));
}
}
Ok(output)
});
match tool_result {
Ok(output) => LanguageModelToolResult {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,
tool_name: tool_use.name, tool_name: tool_use.name,
is_error: false, is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), content: output.llm_output,
output: None, output: Some(output.raw_output),
}, },
Err(error) => LanguageModelToolResult { Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,
@ -511,24 +543,6 @@ impl Thread {
})) }))
} }
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
cx.spawn(async move |_this, cx| {
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
.await
})
}
fn handle_tool_use_json_parse_error_event( fn handle_tool_use_json_parse_error_event(
&mut self, &mut self,
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
@ -572,7 +586,7 @@ impl Thread {
self.messages.last_mut().unwrap() self.messages.last_mut().unwrap()
} }
fn build_completion_request( pub(crate) fn build_completion_request(
&self, &self,
completion_intent: CompletionIntent, completion_intent: CompletionIntent,
cx: &mut App, cx: &mut App,
@ -662,6 +676,7 @@ where
Self: 'static + Sized, Self: 'static + Sized,
{ {
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString; fn name(&self) -> SharedString;
@ -685,23 +700,13 @@ where
schemars::schema_for!(Self::Input) schemars::schema_for!(Self::Input)
} }
/// Allows the tool to authorize a given tool call with the user if necessary
fn authorize(
&self,
input: Self::Input,
event_stream: ToolCallEventStream,
) -> impl use<Self> + Future<Output = Result<()>> {
let json_input = serde_json::json!(&input);
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
}
/// Runs the tool with the provided input. /// Runs the tool with the provided input.
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> { fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self))) Arc::new(Erased(Arc::new(self)))
@ -710,6 +715,11 @@ where
pub struct Erased<T>(T); pub struct Erased<T>(T);
pub struct AgentToolOutput {
llm_output: LanguageModelToolResultContent,
raw_output: serde_json::Value,
}
pub trait AnyAgentTool { pub trait AnyAgentTool {
fn name(&self) -> SharedString; fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString; fn description(&self, cx: &mut App) -> SharedString;
@ -721,7 +731,7 @@ pub trait AnyAgentTool {
input: serde_json::Value, input: serde_json::Value,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<AgentToolOutput>>;
} }
impl<T> AnyAgentTool for Erased<Arc<T>> impl<T> AnyAgentTool for Erased<Arc<T>>
@ -756,12 +766,18 @@ where
input: serde_json::Value, input: serde_json::Value,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<AgentToolOutput>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into); cx.spawn(async move |cx| {
match parsed_input { let input = serde_json::from_value(input)?;
Ok(input) => self.0.clone().run(input, event_stream, cx), let output = cx
Err(error) => Task::ready(Err(anyhow!(error))), .update(|cx| self.0.clone().run(input, event_stream, cx))?
} .await?;
let raw_output = serde_json::to_value(&output)?;
Ok(AgentToolOutput {
llm_output: output.into(),
raw_output,
})
})
} }
} }
@ -874,6 +890,12 @@ impl AgentResponseEventStream {
.ok(); .ok();
} }
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
.ok();
}
fn send_stop(&self, reason: StopReason) { fn send_stop(&self, reason: StopReason) {
match reason { match reason {
StopReason::EndTurn => { StopReason::EndTurn => {
@ -903,13 +925,41 @@ impl AgentResponseEventStream {
#[derive(Clone)] #[derive(Clone)]
pub struct ToolCallEventStream { pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream, stream: AgentResponseEventStream,
} }
impl ToolCallEventStream { impl ToolCallEventStream {
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self { #[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, ToolCallEventStreamReceiver(events_rx))
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
) -> Self {
Self { Self {
tool_use_id, tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream, stream,
} }
} }
@ -918,38 +968,52 @@ impl ToolCallEventStream {
self.stream.send_tool_call_update(&self.tool_use_id, fields); self.stream.send_tool_call_update(&self.tool_use_id, fields);
} }
pub fn authorize( pub fn send_diff(&self, diff: Entity<Diff>) {
&self, self.stream.send_tool_call_diff(ToolCallDiff {
title: String, tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
kind: acp::ToolKind, diff,
input: serde_json::Value, });
) -> impl use<> + Future<Output = Result<()>> { }
self.stream
.authorize_tool_call(&self.tool_use_id, title, kind, input) pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
self.stream.authorize_tool_call(
&self.tool_use_id,
title,
self.kind.clone(),
self.input.clone(),
)
} }
} }
#[cfg(test)] #[cfg(test)]
pub struct TestToolCallEventStream { pub struct ToolCallEventStreamReceiver(
stream: ToolCallEventStream, mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>, );
}
#[cfg(test)] #[cfg(test)]
impl TestToolCallEventStream { impl ToolCallEventStreamReceiver {
pub fn new() -> Self { pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
let (events_tx, events_rx) = let event = self.0.next().await;
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>(); if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
auth
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx)); } else {
panic!("Expected ToolCallAuthorization but got: {:?}", event);
Self {
stream,
_events_rx: events_rx,
} }
} }
}
pub fn stream(&self) -> ToolCallEventStream { #[cfg(test)]
self.stream.clone() impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
} }
} }

View file

@ -1,7 +1,9 @@
mod edit_file_tool;
mod find_path_tool; mod find_path_tool;
mod read_file_tool; mod read_file_tool;
mod thinking_tool; mod thinking_tool;
pub use edit_file_tool::*;
pub use find_path_tool::*; pub use find_path_tool::*;
pub use read_file_tool::*; pub use read_file_tool::*;
pub use thinking_tool::*; pub use thinking_tool::*;

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,8 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task}; use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -8,8 +10,6 @@ use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc}; use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher; use util::paths::PathMatcher;
use crate::{AgentTool, ToolCallEventStream};
/// Fast file path pattern matching tool that works with any codebase size /// Fast file path pattern matching tool that works with any codebase size
/// ///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts" /// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput { pub struct FindPathToolOutput {
paths: Vec<PathBuf>, offset: usize,
current_matches_page: Vec<PathBuf>,
all_matches_len: usize,
}
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
fn from(output: FindPathToolOutput) -> Self {
if output.current_matches_page.is_empty() {
"No matches found".into()
} else {
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
if output.all_matches_len > RESULTS_PER_PAGE {
write!(
&mut llm_output,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
output.offset + 1,
output.offset + output.current_matches_page.len()
)
.unwrap();
}
for mat in output.current_matches_page {
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
}
llm_output.into()
}
}
} }
const RESULTS_PER_PAGE: usize = 50; const RESULTS_PER_PAGE: usize = 50;
@ -57,6 +84,7 @@ impl FindPathTool {
impl AgentTool for FindPathTool { impl AgentTool for FindPathTool {
type Input = FindPathToolInput; type Input = FindPathToolInput;
type Output = FindPathToolOutput;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"find_path".into() "find_path".into()
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move { cx.background_spawn(async move {
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
..Default::default() ..Default::default()
}); });
if matches.is_empty() { Ok(FindPathToolOutput {
Ok("No matches found".into()) offset: input.offset,
} else { current_matches_page: paginated_matches.to_vec(),
let mut message = format!("Found {} total matches.", matches.len()); all_matches_len: matches.len(),
if matches.len() > RESULTS_PER_PAGE { })
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
input.offset + 1,
input.offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
}) })
} }
} }

View file

@ -1,10 +1,11 @@
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context, Result};
use assistant_tool::{outline, ActionLog}; use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task}; use gpui::{Entity, Task};
use indoc::formatdoc; use indoc::formatdoc;
use language::{Anchor, Point}; use language::{Anchor, Point};
use project::{AgentLocation, Project, WorktreeSettings}; use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::Settings; use settings::Settings;
@ -59,6 +60,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool { impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput; type Input = ReadFileToolInput;
type Output = LanguageModelToolResultContent;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"read_file".into() "read_file".into()
@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool {
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, _event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
}; };
@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool {
let file_path = input.path.clone(); let file_path = input.path.clone();
event_stream.send_update(acp::ToolCallUpdateFields { if image_store::is_image_file(&self.project, &project_path, cx) {
locations: Some(vec![acp::ToolCallLocation { return cx.spawn(async move |cx| {
path: project_path.path.to_path_buf(), let image_entity: Entity<ImageItem> = cx
line: input.start_line, .update(|cx| {
// TODO (tracked): use full range self.project.update(cx, |project, cx| {
}]), project.open_image(project_path.clone(), cx)
..Default::default() })
}); })?
.await?;
// TODO (tracked): images let image =
// if image_store::is_image_file(&self.project, &project_path, cx) { image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let model = &self.thread.read(cx).selected_model;
// if !model.supports_images() { let language_model_image = cx
// return Task::ready(Err(anyhow!( .update(|cx| LanguageModelImage::from_image(image, cx))?
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.", .await
// model.name().0 .context("processing image")?;
// )))
// .into();
// }
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> { Ok(language_model_image.into())
// let image_entity: Entity<ImageItem> = cx });
// .update(|cx| { }
// self.project.update(cx, |project, cx| {
// project.open_image(project_path.clone(), cx)
// })
// })?
// .await?;
// let image =
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let language_model_image = cx
// .update(|cx| LanguageModelImage::from_image(image, cx))?
// .await
// .context("processing image")?;
// Ok(ToolResultOutput {
// content: ToolResultContent::Image(language_model_image),
// output: None,
// })
// });
// }
//
let project = self.project.clone(); let project = self.project.clone();
let action_log = self.action_log.clone(); let action_log = self.action_log.clone();
@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool {
})?; })?;
} }
Ok(result) Ok(result.into())
} else { } else {
// No line ranges specified, so check file size to see if it's too big. // No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer, cx); log.buffer_read(buffer, cx);
})?; })?;
Ok(result) Ok(result.into())
} else { } else {
// File is too big, so return the outline // File is too big, so return the outline
// and a suggestion to read again with line numbers. // and a suggestion to read again with line numbers.
@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool {
Alternatively, you can fall back to the `grep` tool (if available) Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content." to search the file for specific content."
}) }
.into())
} }
} }
}) })
@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::TestToolCallEventStream;
use super::*; use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +281,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new(); let (event_stream, _) = ToolCallEventStream::test();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
@ -313,7 +290,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, event_stream, cx)
}) })
.await; .await;
assert_eq!( assert_eq!(
@ -321,6 +298,7 @@ mod test {
"root/nonexistent_file.txt not found" "root/nonexistent_file.txt not found"
); );
} }
#[gpui::test] #[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) { async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx); init_test(cx);
@ -336,7 +314,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -344,10 +321,10 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "This is a small file content"); assert_eq!(result.unwrap(), "This is a small file content".into());
} }
#[gpui::test] #[gpui::test]
@ -367,18 +344,18 @@ mod test {
language_registry.add(Arc::new(rust_lang())); language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new(); let result = cx
let content = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
path: "root/large_file.rs".into(), path: "root/large_file.rs".into(),
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
let content = result.to_str().unwrap();
assert_eq!( assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(), content.lines().skip(4).take(6).collect::<Vec<_>>(),
@ -399,10 +376,11 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await
let content = result.unwrap(); .unwrap();
let content = result.to_str().unwrap();
let expected_content = (0..1000) let expected_content = (0..1000)
.flat_map(|i| { .flat_map(|i| {
vec![ vec![
@ -438,7 +416,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx let result = cx
.update(|cx| { .update(|cx| {
let input = ReadFileToolInput { let input = ReadFileToolInput {
@ -446,10 +423,10 @@ mod test {
start_line: Some(2), start_line: Some(2),
end_line: Some(4), end_line: Some(4),
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
} }
#[gpui::test] #[gpui::test]
@ -467,7 +444,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1 // start_line of 0 should be treated as 1
let result = cx let result = cx
@ -477,10 +453,10 @@ mod test {
start_line: Some(0), start_line: Some(0),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1\nLine 2"); assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
// end_line of 0 should result in at least 1 line // end_line of 0 should result in at least 1 line
let result = cx let result = cx
@ -490,10 +466,10 @@ mod test {
start_line: Some(1), start_line: Some(1),
end_line: Some(0), end_line: Some(0),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 1"); assert_eq!(result.unwrap(), "Line 1".into());
// when start_line > end_line, should still return at least 1 line // when start_line > end_line, should still return at least 1 line
let result = cx let result = cx
@ -503,10 +479,10 @@ mod test {
start_line: Some(3), start_line: Some(3),
end_line: Some(2), end_line: Some(2),
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert_eq!(result.unwrap(), "Line 3"); assert_eq!(result.unwrap(), "Line 3".into());
} }
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
@ -612,7 +588,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log)); let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail // Reading a file outside the project worktree should fail
let result = cx let result = cx
@ -622,7 +597,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -638,7 +613,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -654,7 +629,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -669,7 +644,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -685,7 +660,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -700,7 +675,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -715,7 +690,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -731,11 +706,11 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!(result.is_ok(), "Should be able to read normal files"); assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(result.unwrap(), "Normal file content"); assert_eq!(result.unwrap(), "Normal file content".into());
// Path traversal attempts with .. should fail // Path traversal attempts with .. should fail
let result = cx let result = cx
@ -745,7 +720,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.run(input, event_stream.stream(), cx) tool.run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
assert!( assert!(
@ -826,7 +801,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone())); let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1 // Test reading allowed files in worktree1
let result = cx let result = cx
@ -836,12 +810,15 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }"); assert_eq!(
result,
"fn main() { println!(\"Hello from worktree1\"); }".into()
);
// Test reading private file in worktree1 should fail // Test reading private file in worktree1 should fail
let result = cx let result = cx
@ -851,7 +828,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -872,7 +849,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -893,14 +870,14 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
result, result,
"export function greet() { return 'Hello from worktree2'; }" "export function greet() { return 'Hello from worktree2'; }".into()
); );
// Test reading private file in worktree2 should fail // Test reading private file in worktree2 should fail
@ -911,7 +888,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -932,7 +909,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;
@ -954,7 +931,7 @@ mod test {
start_line: None, start_line: None,
end_line: None, end_line: None,
}; };
tool.clone().run(input, event_stream.stream(), cx) tool.clone().run(input, ToolCallEventStream::test().0, cx)
}) })
.await; .await;

View file

@ -20,6 +20,7 @@ pub struct ThinkingTool;
impl AgentTool for ThinkingTool { impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput; type Input = ThinkingToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"thinking".into() "thinking".into()

View file

@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{ use ::acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
}; };
@ -732,7 +732,11 @@ impl AcpThreadView {
cx: &App, cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> { ) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?; let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(entry.diffs().map(|diff| diff.multibuffer.clone())) Some(
entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone()),
)
} }
fn authenticate( fn authenticate(
@ -1314,10 +1318,9 @@ impl AcpThreadView {
Empty.into_any_element() Empty.into_any_element()
} }
} }
ToolCallContent::Diff { ToolCallContent::Diff { diff, .. } => {
diff: Diff { multibuffer, .. }, self.render_diff_editor(&diff.read(cx).multibuffer())
.. }
} => self.render_diff_editor(multibuffer),
} }
} }

View file

@ -2,7 +2,7 @@ mod copy_path_tool;
mod create_directory_tool; mod create_directory_tool;
mod delete_path_tool; mod delete_path_tool;
mod diagnostics_tool; mod diagnostics_tool;
mod edit_agent; pub mod edit_agent;
mod edit_file_tool; mod edit_file_tool;
mod fetch_tool; mod fetch_tool;
mod find_path_tool; mod find_path_tool;
@ -14,7 +14,7 @@ mod open_tool;
mod project_notifications_tool; mod project_notifications_tool;
mod read_file_tool; mod read_file_tool;
mod schema; mod schema;
mod templates; pub mod templates;
mod terminal_tool; mod terminal_tool;
mod thinking_tool; mod thinking_tool;
mod ui; mod ui;

View file

@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize};
use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll}; use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff}; use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
#[derive(Serialize)] #[derive(Serialize)]
struct CreateFilePromptTemplate { struct CreateFilePromptTemplate {
@ -682,11 +681,6 @@ impl EditAgent {
if last_message.content.is_empty() { if last_message.content.is_empty() {
conversation.messages.pop(); conversation.messages.pop();
} }
} else {
debug_panic!(
"Last message must be an Assistant tool calling! Got {:?}",
last_message.content
);
} }
} }

View file

@ -120,8 +120,6 @@ struct PartialInput {
display_description: String, display_description: String,
} }
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool { impl Tool for EditFileTool {
fn name(&self) -> String { fn name(&self) -> String {
"edit_file".into() "edit_file".into()
@ -211,22 +209,6 @@ impl Tool for EditFileTool {
} }
} }
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run( fn run(
self: Arc<Self>, self: Arc<Self>,
input: serde_json::Value, input: serde_json::Value,
@ -1370,73 +1352,6 @@ mod tests {
assert_eq!(actual, expected); assert_eq!(actual, expected);
} }
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
fn init_test(cx: &mut TestAppContext) { fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| { cx.update(|cx| {
let settings_store = SettingsStore::test(cx); let settings_store = SettingsStore::test(cx);

View file

@ -297,6 +297,12 @@ impl From<String> for LanguageModelToolResultContent {
} }
} }
impl From<LanguageModelImage> for LanguageModelToolResultContent {
fn from(image: LanguageModelImage) -> Self {
Self::Image(image)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent { pub enum MessageContent {
Text(String), Text(String),