agent2: Port edit_file tool (#35844)

TODO:
- [x] Authorization
- [x] Restore tests

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Agus Zubiaga 2025-08-08 09:43:53 -03:00 committed by GitHub
parent d705585a2e
commit 2526dcb5a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 2075 additions and 414 deletions

4
Cargo.lock generated
View file

@ -158,8 +158,10 @@ dependencies = [
"acp_thread",
"agent-client-protocol",
"agent_servers",
"agent_settings",
"anyhow",
"assistant_tool",
"assistant_tools",
"client",
"clock",
"cloud_llm_client",
@ -177,6 +179,8 @@ dependencies = [
"language_model",
"language_models",
"log",
"lsp",
"paths",
"pretty_assertions",
"project",
"prompt_store",

View file

@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = { version = "0.0.23" }
agent-client-protocol = "0.0.23"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"

View file

@ -1,18 +1,17 @@
mod connection;
mod diff;
pub use connection::*;
pub use diff::*;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_tool::ActionLog;
use buffer_diff::BufferDiff;
use editor::{Bias, MultiBuffer, PathKey};
use editor::Bias;
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use itertools::Itertools;
use language::{
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
text_diff,
};
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
use markdown::Markdown;
use project::{AgentLocation, Project};
use std::collections::HashMap;
@ -140,7 +139,7 @@ impl AgentThreadEntry {
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs())
} else {
@ -249,7 +248,7 @@ impl ToolCall {
}
}
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
self.content.iter().filter_map(|content| match content {
ToolCallContent::ContentBlock { .. } => None,
ToolCallContent::Diff { diff } => Some(diff),
@ -389,7 +388,7 @@ impl ContentBlock {
#[derive(Debug)]
pub enum ToolCallContent {
ContentBlock { content: ContentBlock },
Diff { diff: Diff },
Diff { diff: Entity<Diff> },
}
impl ToolCallContent {
@ -403,7 +402,7 @@ impl ToolCallContent {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
diff: Diff::from_acp(diff, language_registry, cx),
diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
},
}
}
@ -411,108 +410,11 @@ impl ToolCallContent {
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
Self::Diff { diff } => diff.to_markdown(cx),
Self::Diff { diff } => diff.read(cx).to_markdown(cx),
}
}
}
#[derive(Debug)]
pub struct Diff {
pub multibuffer: Entity<MultiBuffer>,
pub path: PathBuf,
_task: Task<Result<()>>,
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
async move |cx| {
let language = language_registry
.language_for_file_path(&path)
.await
.log_err();
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
buffer.set_language(language, cx);
buffer.snapshot()
})?;
buffer_diff
.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry),
new_buffer_snapshot,
cx,
)
})?
.await?;
multibuffer
.update(cx, |multibuffer, cx| {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>()
};
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&new_buffer, cx),
new_buffer.clone(),
hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
})
.log_err();
anyhow::Ok(())
}
});
Self {
multibuffer,
path,
_task: task,
}
}
fn to_markdown(&self, cx: &App) -> String {
let buffer_text = self
.multibuffer
.read(cx)
.all_buffers()
.iter()
.map(|buffer| buffer.read(cx).text())
.join("\n");
format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
}
}
#[derive(Debug, Default)]
pub struct Plan {
pub entries: Vec<PlanEntry>,
@ -823,6 +725,21 @@ impl AcpThread {
Ok(())
}
pub fn set_tool_call_diff(
&mut self,
tool_call_id: &acp::ToolCallId,
diff: Entity<Diff>,
cx: &mut Context<Self>,
) -> Result<()> {
let (ix, current_call) = self
.tool_call_mut(tool_call_id)
.context("Tool call not found")?;
current_call.content.clear();
current_call.content.push(ToolCallContent::Diff { diff });
cx.emit(AcpThreadEvent::EntryUpdated(ix));
Ok(())
}
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
let status = ToolCallStatus::Allowed {

View file

@ -0,0 +1,388 @@
use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task};
use itertools::Itertools;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer,
};
use std::{
cmp::Reverse,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use util::ResultExt;
pub enum Diff {
Pending(PendingDiff),
Finalized(FinalizedDiff),
}
impl Diff {
pub fn from_acp(
diff: acp::Diff,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
let acp::Diff {
path,
old_text,
new_text,
} = diff;
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
let task = cx.spawn({
let multibuffer = multibuffer.clone();
let path = path.clone();
async move |_, cx| {
let language = language_registry
.language_for_file_path(&path)
.await
.log_err();
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
buffer.set_language(language, cx);
buffer.snapshot()
})?;
buffer_diff
.update(cx, |diff, cx| {
diff.set_base_text(
old_buffer_snapshot,
Some(language_registry),
new_buffer_snapshot,
cx,
)
})?
.await?;
multibuffer
.update(cx, |multibuffer, cx| {
let hunk_ranges = {
let buffer = new_buffer.read(cx);
let diff = buffer_diff.read(cx);
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>()
};
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&new_buffer, cx),
new_buffer.clone(),
hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff, cx);
})
.log_err();
anyhow::Ok(())
}
});
Self::Finalized(FinalizedDiff {
multibuffer,
path,
_update_diff: task,
})
}
pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
let buffer_snapshot = buffer.read(cx).snapshot();
let base_text = buffer_snapshot.text();
let language_registry = buffer.read(cx).language_registry();
let text_snapshot = buffer.read(cx).text_snapshot();
let buffer_diff = cx.new(|cx| {
let mut diff = BufferDiff::new(&text_snapshot, cx);
let _ = diff.set_base_text(
buffer_snapshot.clone(),
language_registry,
text_snapshot,
cx,
);
diff
});
let multibuffer = cx.new(|cx| {
let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
multibuffer.add_diff(buffer_diff.clone(), cx);
multibuffer
});
Self::Pending(PendingDiff {
multibuffer,
base_text: Arc::new(base_text),
_subscription: cx.observe(&buffer, |this, _, cx| {
if let Diff::Pending(diff) = this {
diff.update(cx);
}
}),
buffer,
diff: buffer_diff,
revealed_ranges: Vec::new(),
update_diff: Task::ready(Ok(())),
})
}
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
if let Self::Pending(diff) = self {
diff.reveal_range(range, cx);
}
}
pub fn finalize(&mut self, cx: &mut Context<Self>) {
if let Self::Pending(diff) = self {
*self = Self::Finalized(diff.finalize(cx));
}
}
pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
match self {
Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer,
Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer,
}
}
pub fn to_markdown(&self, cx: &App) -> String {
let buffer_text = self
.multibuffer()
.read(cx)
.all_buffers()
.iter()
.map(|buffer| buffer.read(cx).text())
.join("\n");
let path = match self {
Diff::Pending(PendingDiff { buffer, .. }) => {
buffer.read(cx).file().map(|file| file.path().as_ref())
}
Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()),
};
format!(
"Diff: {}\n```\n{}\n```\n",
path.unwrap_or(Path::new("untitled")).display(),
buffer_text
)
}
}
pub struct PendingDiff {
multibuffer: Entity<MultiBuffer>,
base_text: Arc<String>,
buffer: Entity<Buffer>,
diff: Entity<BufferDiff>,
revealed_ranges: Vec<Range<Anchor>>,
_subscription: Subscription,
update_diff: Task<Result<()>>,
}
impl PendingDiff {
pub fn update(&mut self, cx: &mut Context<Diff>) {
let buffer = self.buffer.clone();
let buffer_diff = self.diff.clone();
let base_text = self.base_text.clone();
self.update_diff = cx.spawn(async move |diff, cx| {
let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
let diff_snapshot = BufferDiff::update_diff(
buffer_diff.clone(),
text_snapshot.clone(),
Some(base_text),
false,
false,
None,
None,
cx,
)
.await?;
buffer_diff.update(cx, |diff, cx| {
diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
})?;
diff.update(cx, |diff, cx| {
if let Diff::Pending(diff) = diff {
diff.update_visible_ranges(cx);
}
})
});
}
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Diff>) {
self.revealed_ranges.push(range);
self.update_visible_ranges(cx);
}
fn finalize(&self, cx: &mut Context<Diff>) -> FinalizedDiff {
let ranges = self.excerpt_ranges(cx);
let base_text = self.base_text.clone();
let language_registry = self.buffer.read(cx).language_registry().clone();
let path = self
.buffer
.read(cx)
.file()
.map(|file| file.path().as_ref())
.unwrap_or(Path::new("untitled"))
.into();
// Replace the buffer in the multibuffer with the snapshot
let buffer = cx.new(|cx| {
let language = self.buffer.read(cx).language().cloned();
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
self.buffer.read(cx).line_ending(),
self.buffer.read(cx).as_rope().clone(),
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
});
let buffer_diff = cx.spawn({
let buffer = buffer.clone();
let language_registry = language_registry.clone();
async move |_this, cx| {
build_buffer_diff(base_text, &buffer, language_registry, cx).await
}
});
let update_diff = cx.spawn(async move |this, cx| {
let buffer_diff = buffer_diff.await?;
this.update(cx, |this, cx| {
this.multibuffer().update(cx, |multibuffer, cx| {
let path_key = PathKey::for_buffer(&buffer, cx);
multibuffer.clear(cx);
multibuffer.set_excerpts_for_path(
path_key,
buffer,
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
multibuffer.add_diff(buffer_diff.clone(), cx);
});
cx.notify();
})
});
FinalizedDiff {
path,
multibuffer: self.multibuffer.clone(),
_update_diff: update_diff,
}
}
fn update_visible_ranges(&mut self, cx: &mut Context<Diff>) {
let ranges = self.excerpt_ranges(cx);
self.multibuffer.update(cx, |multibuffer, cx| {
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&self.buffer, cx),
self.buffer.clone(),
ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
let end = multibuffer.len(cx);
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
});
cx.notify();
}
fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
let buffer = self.buffer.read(cx);
let diff = self.diff.read(cx);
let mut ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
.collect::<Vec<_>>();
ranges.extend(
self.revealed_ranges
.iter()
.map(|range| range.to_point(&buffer)),
);
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
// Merge adjacent ranges
let mut ranges = ranges.into_iter().peekable();
let mut merged_ranges = Vec::new();
while let Some(mut range) = ranges.next() {
while let Some(next_range) = ranges.peek() {
if range.end >= next_range.start {
range.end = range.end.max(next_range.end);
ranges.next();
} else {
break;
}
}
merged_ranges.push(range);
}
merged_ranges
}
}
pub struct FinalizedDiff {
path: PathBuf,
multibuffer: Entity<MultiBuffer>,
_update_diff: Task<Result<()>>,
}
async fn build_buffer_diff(
old_text: Arc<String>,
buffer: &Entity<Buffer>,
language_registry: Option<Arc<LanguageRegistry>>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let old_text_rope = cx
.background_spawn({
let old_text = old_text.clone();
async move { Rope::from(old_text.as_str()) }
})
.await;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text_rope,
buffer.language().cloned(),
language_registry,
cx,
)
})?
.await;
let diff_snapshot = cx
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text),
base_buffer,
cx,
)
})?
.await;
let secondary_diff = cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer, cx);
diff.set_snapshot(diff_snapshot.clone(), &buffer, cx);
diff
})?;
cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer.text, cx);
diff.set_snapshot(diff_snapshot, &buffer, cx);
diff.set_secondary_diff(secondary_diff);
diff
})
}

View file

@ -15,8 +15,10 @@ workspace = true
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
@ -29,6 +31,7 @@ language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
paths.workspace = true
project.workspace = true
prompt_store.workspace = true
rust-embed.workspace = true
@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
lsp = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }

View file

@ -1,5 +1,5 @@
use crate::{templates::Templates, AgentResponseEvent, Thread};
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|_| {
let thread = cx.new(|cx| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(cx.entity()));
thread
});
@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
)
})??;
}
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
acp_thread.update(cx, |thread, cx| {
thread.set_tool_call_diff(
&tool_call_diff.tool_call_id,
tool_call_diff.diff,
cx,
)
})??;
}
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });

View file

@ -9,5 +9,6 @@ mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
pub use tools::*;

View file

@ -1,5 +1,4 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp};
use anyhow::Result;
@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: None
output: Some("Allowed".into())
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),

View file

@ -14,6 +14,7 @@ pub struct EchoTool;
impl AgentTool for EchoTool {
type Input = EchoToolInput;
type Output = String;
fn name(&self) -> SharedString {
"echo".into()
@ -48,6 +49,7 @@ pub struct DelayTool;
impl AgentTool for DelayTool {
type Input = DelayToolInput;
type Output = String;
fn name(&self) -> SharedString {
"delay".into()
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
type Output = String;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission {
fn run(
self: Arc<Self>,
input: Self::Input,
_input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>
where
Self: Sized,
{
let auth_check = self.authorize(input, event_stream);
) -> Task<Result<String>> {
let auth_check = event_stream.authorize("Authorize?".into());
cx.foreground_executor().spawn(async move {
auth_check.await?;
Ok("Allowed".to_string())
@ -121,6 +121,7 @@ pub struct InfiniteTool;
impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput;
type Output = String;
fn name(&self) -> SharedString {
"infinite".into()
@ -171,19 +172,20 @@ pub struct WordListTool;
impl AgentTool for WordListTool {
type Input = WordListInput;
type Output = String;
fn name(&self) -> SharedString {
"word_list".into()
}
fn initial_title(&self, _input: Self::Input) -> SharedString {
"List of random words".into()
}
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Self::Input) -> SharedString {
"List of random words".into()
}
fn run(
self: Arc<Self>,
_input: Self::Input,

View file

@ -1,4 +1,5 @@
use crate::templates::{SystemPromptTemplate, Template, Templates};
use crate::{SystemPromptTemplate, Template, Templates};
use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{adapt_schema_to_format, ActionLog};
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
ToolCallDiff(ToolCallDiff),
Stop(acp::StopReason),
}
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
#[derive(Debug)]
pub struct ToolCallDiff {
pub tool_call_id: acp::ToolCallId,
pub diff: Entity<acp_thread::Diff>,
}
pub struct Thread {
messages: Vec<AgentMessage>,
completion_mode: CompletionMode,
@ -125,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
_project: Entity<Project>,
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@ -145,10 +154,19 @@ impl Thread {
project_context,
templates,
selected_model: default_model,
project,
action_log,
}
}
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@ -315,10 +333,6 @@ impl Thread {
events_rx
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
@ -490,15 +504,33 @@ impl Thread {
}));
};
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
let tool_event_stream =
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
let supports_images = self.selected_model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
if !supports_images {
return Err(anyhow!(
"Attempted to read an image, but this model doesn't support it.",
));
}
}
Ok(output)
});
match tool_result {
Ok(output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
content: output.llm_output,
output: Some(output.raw_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
@ -511,24 +543,6 @@ impl Thread {
}))
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
cx.spawn(async move |_this, cx| {
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
tool_event_stream.send_update(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
.await
})
}
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -572,7 +586,7 @@ impl Thread {
self.messages.last_mut().unwrap()
}
fn build_completion_request(
pub(crate) fn build_completion_request(
&self,
completion_intent: CompletionIntent,
cx: &mut App,
@ -662,6 +676,7 @@ where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString;
@ -685,23 +700,13 @@ where
schemars::schema_for!(Self::Input)
}
/// Allows the tool to authorize a given tool call with the user if necessary
fn authorize(
&self,
input: Self::Input,
event_stream: ToolCallEventStream,
) -> impl use<Self> + Future<Output = Result<()>> {
let json_input = serde_json::json!(&input);
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
@ -710,6 +715,11 @@ where
pub struct Erased<T>(T);
pub struct AgentToolOutput {
llm_output: LanguageModelToolResultContent,
raw_output: serde_json::Value,
}
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
@ -721,7 +731,7 @@ pub trait AnyAgentTool {
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<AgentToolOutput>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@ -756,12 +766,18 @@ where
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => self.0.clone().run(input, event_stream, cx),
Err(error) => Task::ready(Err(anyhow!(error))),
}
) -> Task<Result<AgentToolOutput>> {
cx.spawn(async move |cx| {
let input = serde_json::from_value(input)?;
let output = cx
.update(|cx| self.0.clone().run(input, event_stream, cx))?
.await?;
let raw_output = serde_json::to_value(&output)?;
Ok(AgentToolOutput {
llm_output: output.into(),
raw_output,
})
})
}
}
@ -874,6 +890,12 @@ impl AgentResponseEventStream {
.ok();
}
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
@ -903,13 +925,41 @@ impl AgentResponseEventStream {
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
}
impl ToolCallEventStream {
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
id: "test_id".into(),
name: "test_tool".into(),
raw_input: String::new(),
input: serde_json::Value::Null,
is_input_complete: true,
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
);
(stream, ToolCallEventStreamReceiver(events_rx))
}
fn new(
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
) -> Self {
Self {
tool_use_id,
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
}
}
@ -918,38 +968,52 @@ impl ToolCallEventStream {
self.stream.send_tool_call_update(&self.tool_use_id, fields);
}
pub fn authorize(
&self,
title: String,
kind: acp::ToolKind,
input: serde_json::Value,
) -> impl use<> + Future<Output = Result<()>> {
self.stream
.authorize_tool_call(&self.tool_use_id, title, kind, input)
pub fn send_diff(&self, diff: Entity<Diff>) {
self.stream.send_tool_call_diff(ToolCallDiff {
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
diff,
});
}
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
self.stream.authorize_tool_call(
&self.tool_use_id,
title,
self.kind.clone(),
self.input.clone(),
)
}
}
#[cfg(test)]
pub struct TestToolCallEventStream {
stream: ToolCallEventStream,
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
}
pub struct ToolCallEventStreamReceiver(
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
#[cfg(test)]
impl TestToolCallEventStream {
pub fn new() -> Self {
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
Self {
stream,
_events_rx: events_rx,
impl ToolCallEventStreamReceiver {
pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
let event = self.0.next().await;
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
auth
} else {
panic!("Expected ToolCallAuthorization but got: {:?}", event);
}
}
}
pub fn stream(&self) -> ToolCallEventStream {
self.stream.clone()
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -1,7 +1,9 @@
mod edit_file_tool;
mod find_path_tool;
mod read_file_tool;
mod thinking_tool;
pub use edit_file_tool::*;
pub use find_path_tool::*;
pub use read_file_tool::*;
pub use thinking_tool::*;

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,8 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -8,8 +10,6 @@ use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use crate::{AgentTool, ToolCallEventStream};
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
}
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
paths: Vec<PathBuf>,
pub struct FindPathToolOutput {
offset: usize,
current_matches_page: Vec<PathBuf>,
all_matches_len: usize,
}
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
fn from(output: FindPathToolOutput) -> Self {
if output.current_matches_page.is_empty() {
"No matches found".into()
} else {
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
if output.all_matches_len > RESULTS_PER_PAGE {
write!(
&mut llm_output,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
output.offset + 1,
output.offset + output.current_matches_page.len()
)
.unwrap();
}
for mat in output.current_matches_page {
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
}
llm_output.into()
}
}
}
const RESULTS_PER_PAGE: usize = 50;
@ -57,6 +84,7 @@ impl FindPathTool {
impl AgentTool for FindPathTool {
type Input = FindPathToolInput;
type Output = FindPathToolOutput;
fn name(&self) -> SharedString {
"find_path".into()
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
..Default::default()
});
if matches.is_empty() {
Ok("No matches found".into())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
input.offset + 1,
input.offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
Ok(FindPathToolOutput {
offset: input.offset,
current_matches_page: paginated_matches.to_vec(),
all_matches_len: matches.len(),
})
})
}
}

View file

@ -1,10 +1,11 @@
use agent_client_protocol::{self as acp};
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task};
use indoc::formatdoc;
use language::{Anchor, Point};
use project::{AgentLocation, Project, WorktreeSettings};
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
@ -59,6 +60,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput;
type Output = LanguageModelToolResultContent;
fn name(&self) -> SharedString {
"read_file".into()
@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool {
let file_path = input.path.clone();
event_stream.send_update(acp::ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: project_path.path.to_path_buf(),
line: input.start_line,
// TODO (tracked): use full range
}]),
..Default::default()
});
if image_store::is_image_file(&self.project, &project_path, cx) {
return cx.spawn(async move |cx| {
let image_entity: Entity<ImageItem> = cx
.update(|cx| {
self.project.update(cx, |project, cx| {
project.open_image(project_path.clone(), cx)
})
})?
.await?;
// TODO (tracked): images
// if image_store::is_image_file(&self.project, &project_path, cx) {
// let model = &self.thread.read(cx).selected_model;
let image =
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// if !model.supports_images() {
// return Task::ready(Err(anyhow!(
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
// model.name().0
// )))
// .into();
// }
let language_model_image = cx
.update(|cx| LanguageModelImage::from_image(image, cx))?
.await
.context("processing image")?;
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
// let image_entity: Entity<ImageItem> = cx
// .update(|cx| {
// self.project.update(cx, |project, cx| {
// project.open_image(project_path.clone(), cx)
// })
// })?
// .await?;
// let image =
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
// let language_model_image = cx
// .update(|cx| LanguageModelImage::from_image(image, cx))?
// .await
// .context("processing image")?;
// Ok(ToolResultOutput {
// content: ToolResultContent::Image(language_model_image),
// output: None,
// })
// });
// }
//
Ok(language_model_image.into())
});
}
let project = self.project.clone();
let action_log = self.action_log.clone();
@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool {
})?;
}
Ok(result)
Ok(result.into())
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer, cx);
})?;
Ok(result)
Ok(result.into())
} else {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool {
Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content."
})
}
.into())
}
}
})
@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
use crate::TestToolCallEventStream;
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@ -304,7 +281,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
@ -313,7 +290,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, event_stream, cx)
})
.await;
assert_eq!(
@ -321,6 +298,7 @@ mod test {
"root/nonexistent_file.txt not found"
);
}
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
@ -336,7 +314,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -344,10 +321,10 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "This is a small file content");
assert_eq!(result.unwrap(), "This is a small file content".into());
}
#[gpui::test]
@ -367,18 +344,18 @@ mod test {
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let content = cx
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
let content = result.to_str().unwrap();
assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(),
@ -399,10 +376,11 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
let content = result.unwrap();
.await
.unwrap();
let content = result.to_str().unwrap();
let expected_content = (0..1000)
.flat_map(|i| {
vec![
@ -438,7 +416,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@ -446,10 +423,10 @@ mod test {
start_line: Some(2),
end_line: Some(4),
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
}
#[gpui::test]
@ -467,7 +444,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1
let result = cx
@ -477,10 +453,10 @@ mod test {
start_line: Some(0),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
// end_line of 0 should result in at least 1 line
let result = cx
@ -490,10 +466,10 @@ mod test {
start_line: Some(1),
end_line: Some(0),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 1");
assert_eq!(result.unwrap(), "Line 1".into());
// when start_line > end_line, should still return at least 1 line
let result = cx
@ -503,10 +479,10 @@ mod test {
start_line: Some(3),
end_line: Some(2),
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(result.unwrap(), "Line 3");
assert_eq!(result.unwrap(), "Line 3".into());
}
fn init_test(cx: &mut TestAppContext) {
@ -612,7 +588,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail
let result = cx
@ -622,7 +597,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -638,7 +613,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -654,7 +629,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -669,7 +644,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -685,7 +660,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -700,7 +675,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -715,7 +690,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -731,11 +706,11 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
assert_eq!(result.unwrap(), "Normal file content");
assert_eq!(result.unwrap(), "Normal file content".into());
// Path traversal attempts with .. should fail
let result = cx
@ -745,7 +720,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.run(input, event_stream.stream(), cx)
tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@ -826,7 +801,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1
let result = cx
@ -836,12 +810,15 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
assert_eq!(
result,
"fn main() { println!(\"Hello from worktree1\"); }".into()
);
// Test reading private file in worktree1 should fail
let result = cx
@ -851,7 +828,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -872,7 +849,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -893,14 +870,14 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
assert_eq!(
result,
"export function greet() { return 'Hello from worktree2'; }"
"export function greet() { return 'Hello from worktree2'; }".into()
);
// Test reading private file in worktree2 should fail
@ -911,7 +888,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -932,7 +909,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@ -954,7 +931,7 @@ mod test {
start_line: None,
end_line: None,
};
tool.clone().run(input, event_stream.stream(), cx)
tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;

View file

@ -20,6 +20,7 @@ pub struct ThinkingTool;
impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput;
type Output = String;
fn name(&self) -> SharedString {
"thinking".into()

View file

@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
};
@ -732,7 +732,11 @@ impl AcpThreadView {
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
Some(
entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone()),
)
}
fn authenticate(
@ -1314,10 +1318,9 @@ impl AcpThreadView {
Empty.into_any_element()
}
}
ToolCallContent::Diff {
diff: Diff { multibuffer, .. },
..
} => self.render_diff_editor(multibuffer),
ToolCallContent::Diff { diff, .. } => {
self.render_diff_editor(&diff.read(cx).multibuffer())
}
}
}

View file

@ -2,7 +2,7 @@ mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
pub mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@ -14,7 +14,7 @@ mod open_tool;
mod project_notifications_tool;
mod read_file_tool;
mod schema;
mod templates;
pub mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;

View file

@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize};
use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
#[derive(Serialize)]
struct CreateFilePromptTemplate {
@ -682,11 +681,6 @@ impl EditAgent {
if last_message.content.is_empty() {
conversation.messages.pop();
}
} else {
debug_panic!(
"Last message must be an Assistant tool calling! Got {:?}",
last_message.content
);
}
}

View file

@ -120,8 +120,6 @@ struct PartialInput {
display_description: String,
}
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
@ -211,22 +209,6 @@ impl Tool for EditFileTool {
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
@ -1370,73 +1352,6 @@ mod tests {
assert_eq!(actual, expected);
}
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);

View file

@ -297,6 +297,12 @@ impl From<String> for LanguageModelToolResultContent {
}
}
impl From<LanguageModelImage> for LanguageModelToolResultContent {
fn from(image: LanguageModelImage) -> Self {
Self::Image(image)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
Text(String),