Compare commits
26 commits
main
...
edit-file-
Author | SHA1 | Date | |
---|---|---|---|
![]() |
320312fa25 | ||
![]() |
ebc7df2c2e | ||
![]() |
4d5b22a583 | ||
![]() |
970b5fe06e | ||
![]() |
63b625236d | ||
![]() |
e06f54054a | ||
![]() |
8f390d9c6d | ||
![]() |
da5f2978fd | ||
![]() |
294109c6da | ||
![]() |
d52e0f47b5 | ||
![]() |
26cf0cd9df | ||
![]() |
a7bcc0f97a | ||
![]() |
dbba5c8967 | ||
![]() |
5bfe086bf4 | ||
![]() |
4d94b9c2c2 | ||
![]() |
168e55db53 | ||
![]() |
160e6d5747 | ||
![]() |
6769b650d0 | ||
![]() |
c943f5a847 | ||
![]() |
afb9554a28 | ||
![]() |
c050527d90 | ||
![]() |
0d24686a9c | ||
![]() |
342247f60f | ||
![]() |
bf5c097732 | ||
![]() |
6dda5b9d86 | ||
![]() |
f0bf2e79d6 |
20 changed files with 2124 additions and 426 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
@ -158,8 +158,10 @@ dependencies = [
|
|||
"acp_thread",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
|
@ -177,6 +179,8 @@ dependencies = [
|
|||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"lsp",
|
||||
"paths",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_store",
|
||||
|
|
|
@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
|||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = { version = "0.0.23" }
|
||||
agent-client-protocol = "0.0.23"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
mod connection;
|
||||
mod diff;
|
||||
|
||||
pub use connection::*;
|
||||
pub use diff::*;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use assistant_tool::ActionLog;
|
||||
use buffer_diff::BufferDiff;
|
||||
use editor::{Bias, MultiBuffer, PathKey};
|
||||
use editor::Bias;
|
||||
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
|
||||
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use itertools::Itertools;
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
|
||||
text_diff,
|
||||
};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
|
||||
use markdown::Markdown;
|
||||
use project::{AgentLocation, Project};
|
||||
use std::collections::HashMap;
|
||||
|
@ -140,7 +139,7 @@ impl AgentThreadEntry {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||
if let AgentThreadEntry::ToolCall(call) = self {
|
||||
itertools::Either::Left(call.diffs())
|
||||
} else {
|
||||
|
@ -249,7 +248,7 @@ impl ToolCall {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||
self.content.iter().filter_map(|content| match content {
|
||||
ToolCallContent::ContentBlock { .. } => None,
|
||||
ToolCallContent::Diff { diff } => Some(diff),
|
||||
|
@ -389,7 +388,7 @@ impl ContentBlock {
|
|||
#[derive(Debug)]
|
||||
pub enum ToolCallContent {
|
||||
ContentBlock { content: ContentBlock },
|
||||
Diff { diff: Diff },
|
||||
Diff { diff: Entity<Diff> },
|
||||
}
|
||||
|
||||
impl ToolCallContent {
|
||||
|
@ -403,7 +402,7 @@ impl ToolCallContent {
|
|||
content: ContentBlock::new(content, &language_registry, cx),
|
||||
},
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||
diff: Diff::from_acp(diff, language_registry, cx),
|
||||
diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -411,108 +410,11 @@ impl ToolCallContent {
|
|||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
match self {
|
||||
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
|
||||
Self::Diff { diff } => diff.to_markdown(cx),
|
||||
Self::Diff { diff } => diff.read(cx).to_markdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Diff {
|
||||
pub multibuffer: Entity<MultiBuffer>,
|
||||
pub path: PathBuf,
|
||||
_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl Diff {
|
||||
pub fn from_acp(
|
||||
diff: acp::Diff,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let acp::Diff {
|
||||
path,
|
||||
old_text,
|
||||
new_text,
|
||||
} = diff;
|
||||
|
||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
|
||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
|
||||
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
|
||||
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
|
||||
|
||||
let task = cx.spawn({
|
||||
let multibuffer = multibuffer.clone();
|
||||
let path = path.clone();
|
||||
async move |cx| {
|
||||
let language = language_registry
|
||||
.language_for_file_path(&path)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
|
||||
|
||||
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_language(language, cx);
|
||||
buffer.snapshot()
|
||||
})?;
|
||||
|
||||
buffer_diff
|
||||
.update(cx, |diff, cx| {
|
||||
diff.set_base_text(
|
||||
old_buffer_snapshot,
|
||||
Some(language_registry),
|
||||
new_buffer_snapshot,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
multibuffer
|
||||
.update(cx, |multibuffer, cx| {
|
||||
let hunk_ranges = {
|
||||
let buffer = new_buffer.read(cx);
|
||||
let diff = buffer_diff.read(cx);
|
||||
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&new_buffer, cx),
|
||||
new_buffer.clone(),
|
||||
hunk_ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
multibuffer,
|
||||
path,
|
||||
_task: task,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
let buffer_text = self
|
||||
.multibuffer
|
||||
.read(cx)
|
||||
.all_buffers()
|
||||
.iter()
|
||||
.map(|buffer| buffer.read(cx).text())
|
||||
.join("\n");
|
||||
format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Plan {
|
||||
pub entries: Vec<PlanEntry>,
|
||||
|
@ -823,6 +725,21 @@ impl AcpThread {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_tool_call_diff(
|
||||
&mut self,
|
||||
tool_call_id: &acp::ToolCallId,
|
||||
diff: Entity<Diff>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
let (ix, current_call) = self
|
||||
.tool_call_mut(tool_call_id)
|
||||
.context("Tool call not found")?;
|
||||
current_call.content.clear();
|
||||
current_call.content.push(ToolCallContent::Diff { diff });
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
||||
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
||||
let status = ToolCallStatus::Allowed {
|
||||
|
|
388
crates/acp_thread/src/diff.rs
Normal file
388
crates/acp_thread/src/diff.rs
Normal file
|
@ -0,0 +1,388 @@
|
|||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{MultiBuffer, PathKey};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task};
|
||||
use itertools::Itertools;
|
||||
use language::{
|
||||
Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer,
|
||||
};
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
pub enum Diff {
|
||||
Pending(PendingDiff),
|
||||
Finalized(FinalizedDiff),
|
||||
}
|
||||
|
||||
impl Diff {
|
||||
pub fn from_acp(
|
||||
diff: acp::Diff,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let acp::Diff {
|
||||
path,
|
||||
old_text,
|
||||
new_text,
|
||||
} = diff;
|
||||
|
||||
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
|
||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
|
||||
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
|
||||
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
|
||||
|
||||
let task = cx.spawn({
|
||||
let multibuffer = multibuffer.clone();
|
||||
let path = path.clone();
|
||||
async move |_, cx| {
|
||||
let language = language_registry
|
||||
.language_for_file_path(&path)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
|
||||
|
||||
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_language(language, cx);
|
||||
buffer.snapshot()
|
||||
})?;
|
||||
|
||||
buffer_diff
|
||||
.update(cx, |diff, cx| {
|
||||
diff.set_base_text(
|
||||
old_buffer_snapshot,
|
||||
Some(language_registry),
|
||||
new_buffer_snapshot,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
multibuffer
|
||||
.update(cx, |multibuffer, cx| {
|
||||
let hunk_ranges = {
|
||||
let buffer = new_buffer.read(cx);
|
||||
let diff = buffer_diff.read(cx);
|
||||
diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&new_buffer, cx),
|
||||
new_buffer.clone(),
|
||||
hunk_ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
Self::Finalized(FinalizedDiff {
|
||||
multibuffer,
|
||||
path,
|
||||
_update_diff: task,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
let base_text = buffer_snapshot.text();
|
||||
let language_registry = buffer.read(cx).language_registry();
|
||||
let text_snapshot = buffer.read(cx).text_snapshot();
|
||||
let buffer_diff = cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&text_snapshot, cx);
|
||||
let _ = diff.set_base_text(
|
||||
buffer_snapshot.clone(),
|
||||
language_registry,
|
||||
text_snapshot,
|
||||
cx,
|
||||
);
|
||||
diff
|
||||
});
|
||||
|
||||
let multibuffer = cx.new(|cx| {
|
||||
let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
|
||||
multibuffer.add_diff(buffer_diff.clone(), cx);
|
||||
multibuffer
|
||||
});
|
||||
|
||||
Self::Pending(PendingDiff {
|
||||
multibuffer,
|
||||
base_text: Arc::new(base_text),
|
||||
_subscription: cx.observe(&buffer, |this, _, cx| {
|
||||
if let Diff::Pending(diff) = this {
|
||||
diff.update(cx);
|
||||
}
|
||||
}),
|
||||
buffer,
|
||||
diff: buffer_diff,
|
||||
revealed_ranges: Vec::new(),
|
||||
update_diff: Task::ready(Ok(())),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
|
||||
if let Self::Pending(diff) = self {
|
||||
diff.reveal_range(range, cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn finalize(&mut self, cx: &mut Context<Self>) {
|
||||
if let Self::Pending(diff) = self {
|
||||
*self = Self::Finalized(diff.finalize(cx));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
|
||||
match self {
|
||||
Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer,
|
||||
Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
let buffer_text = self
|
||||
.multibuffer()
|
||||
.read(cx)
|
||||
.all_buffers()
|
||||
.iter()
|
||||
.map(|buffer| buffer.read(cx).text())
|
||||
.join("\n");
|
||||
let path = match self {
|
||||
Diff::Pending(PendingDiff { buffer, .. }) => {
|
||||
buffer.read(cx).file().map(|file| file.path().as_ref())
|
||||
}
|
||||
Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()),
|
||||
};
|
||||
format!(
|
||||
"Diff: {}\n```\n{}\n```\n",
|
||||
path.unwrap_or(Path::new("untitled")).display(),
|
||||
buffer_text
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PendingDiff {
|
||||
multibuffer: Entity<MultiBuffer>,
|
||||
base_text: Arc<String>,
|
||||
buffer: Entity<Buffer>,
|
||||
diff: Entity<BufferDiff>,
|
||||
revealed_ranges: Vec<Range<Anchor>>,
|
||||
_subscription: Subscription,
|
||||
update_diff: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl PendingDiff {
|
||||
pub fn update(&mut self, cx: &mut Context<Diff>) {
|
||||
let buffer = self.buffer.clone();
|
||||
let buffer_diff = self.diff.clone();
|
||||
let base_text = self.base_text.clone();
|
||||
self.update_diff = cx.spawn(async move |diff, cx| {
|
||||
let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
|
||||
let diff_snapshot = BufferDiff::update_diff(
|
||||
buffer_diff.clone(),
|
||||
text_snapshot.clone(),
|
||||
Some(base_text),
|
||||
false,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
buffer_diff.update(cx, |diff, cx| {
|
||||
diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
|
||||
})?;
|
||||
diff.update(cx, |diff, cx| {
|
||||
if let Diff::Pending(diff) = diff {
|
||||
diff.update_visible_ranges(cx);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Diff>) {
|
||||
self.revealed_ranges.push(range);
|
||||
self.update_visible_ranges(cx);
|
||||
}
|
||||
|
||||
fn finalize(&self, cx: &mut Context<Diff>) -> FinalizedDiff {
|
||||
let ranges = self.excerpt_ranges(cx);
|
||||
let base_text = self.base_text.clone();
|
||||
let language_registry = self.buffer.read(cx).language_registry().clone();
|
||||
|
||||
let path = self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map(|file| file.path().as_ref())
|
||||
.unwrap_or(Path::new("untitled"))
|
||||
.into();
|
||||
|
||||
// Replace the buffer in the multibuffer with the snapshot
|
||||
let buffer = cx.new(|cx| {
|
||||
let language = self.buffer.read(cx).language().cloned();
|
||||
let buffer = TextBuffer::new_normalized(
|
||||
0,
|
||||
cx.entity_id().as_non_zero_u64().into(),
|
||||
self.buffer.read(cx).line_ending(),
|
||||
self.buffer.read(cx).as_rope().clone(),
|
||||
);
|
||||
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
|
||||
buffer.set_language(language, cx);
|
||||
buffer
|
||||
});
|
||||
|
||||
let buffer_diff = cx.spawn({
|
||||
let buffer = buffer.clone();
|
||||
let language_registry = language_registry.clone();
|
||||
async move |_this, cx| {
|
||||
build_buffer_diff(base_text, &buffer, language_registry, cx).await
|
||||
}
|
||||
});
|
||||
|
||||
let update_diff = cx.spawn(async move |this, cx| {
|
||||
let buffer_diff = buffer_diff.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.multibuffer().update(cx, |multibuffer, cx| {
|
||||
let path_key = PathKey::for_buffer(&buffer, cx);
|
||||
multibuffer.clear(cx);
|
||||
multibuffer.set_excerpts_for_path(
|
||||
path_key,
|
||||
buffer,
|
||||
ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
multibuffer.add_diff(buffer_diff.clone(), cx);
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
});
|
||||
|
||||
FinalizedDiff {
|
||||
path,
|
||||
multibuffer: self.multibuffer.clone(),
|
||||
_update_diff: update_diff,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_visible_ranges(&mut self, cx: &mut Context<Diff>) {
|
||||
let ranges = self.excerpt_ranges(cx);
|
||||
self.multibuffer.update(cx, |multibuffer, cx| {
|
||||
multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&self.buffer, cx),
|
||||
self.buffer.clone(),
|
||||
ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
let end = multibuffer.len(cx);
|
||||
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
|
||||
let buffer = self.buffer.read(cx);
|
||||
let diff = self.diff.read(cx);
|
||||
let mut ranges = diff
|
||||
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
|
||||
.collect::<Vec<_>>();
|
||||
ranges.extend(
|
||||
self.revealed_ranges
|
||||
.iter()
|
||||
.map(|range| range.to_point(&buffer)),
|
||||
);
|
||||
ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
|
||||
|
||||
// Merge adjacent ranges
|
||||
let mut ranges = ranges.into_iter().peekable();
|
||||
let mut merged_ranges = Vec::new();
|
||||
while let Some(mut range) = ranges.next() {
|
||||
while let Some(next_range) = ranges.peek() {
|
||||
if range.end >= next_range.start {
|
||||
range.end = range.end.max(next_range.end);
|
||||
ranges.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
merged_ranges.push(range);
|
||||
}
|
||||
merged_ranges
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FinalizedDiff {
|
||||
path: PathBuf,
|
||||
multibuffer: Entity<MultiBuffer>,
|
||||
_update_diff: Task<Result<()>>,
|
||||
}
|
||||
|
||||
async fn build_buffer_diff(
|
||||
old_text: Arc<String>,
|
||||
buffer: &Entity<Buffer>,
|
||||
language_registry: Option<Arc<LanguageRegistry>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<BufferDiff>> {
|
||||
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
|
||||
let old_text_rope = cx
|
||||
.background_spawn({
|
||||
let old_text = old_text.clone();
|
||||
async move { Rope::from(old_text.as_str()) }
|
||||
})
|
||||
.await;
|
||||
let base_buffer = cx
|
||||
.update(|cx| {
|
||||
Buffer::build_snapshot(
|
||||
old_text_rope,
|
||||
buffer.language().cloned(),
|
||||
language_registry,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let diff_snapshot = cx
|
||||
.update(|cx| {
|
||||
BufferDiffSnapshot::new_with_base_buffer(
|
||||
buffer.text.clone(),
|
||||
Some(old_text),
|
||||
base_buffer,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let secondary_diff = cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&buffer, cx);
|
||||
diff.set_snapshot(diff_snapshot.clone(), &buffer, cx);
|
||||
diff
|
||||
})?;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&buffer.text, cx);
|
||||
diff.set_snapshot(diff_snapshot, &buffer, cx);
|
||||
diff.set_secondary_diff(secondary_diff);
|
||||
diff
|
||||
})
|
||||
}
|
|
@ -15,8 +15,10 @@ workspace = true
|
|||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
|
@ -29,6 +31,7 @@ language.workspace = true
|
|||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
rust-embed.workspace = true
|
||||
|
@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] }
|
|||
gpui_tokio.workspace = true
|
||||
language = { workspace = true, "features" = ["test-support"] }
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
lsp = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
|
||||
use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
|
||||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
|
@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
thread
|
||||
});
|
||||
|
||||
|
@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.set_tool_call_diff(
|
||||
&tool_call_diff.tool_call_id,
|
||||
tool_call_diff.diff,
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
|
|
|
@ -9,5 +9,6 @@ mod tests;
|
|||
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
|
@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
|||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: None
|
||||
output: Some("Allowed".into())
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
|
@ -648,6 +647,19 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate streaming partial input.
|
||||
let input = json!({});
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "1".into(),
|
||||
name: ThinkingTool.name().into(),
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: false,
|
||||
},
|
||||
));
|
||||
|
||||
// Input streaming completed
|
||||
let input = json!({ "content": "Thinking hard!" });
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
|
@ -666,12 +678,12 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
tool_call,
|
||||
acp::ToolCall {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
title: "Thinking".into(),
|
||||
title: "thinking".into(),
|
||||
kind: acp::ToolKind::Think,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||
raw_input: Some(json!({})),
|
||||
raw_output: None,
|
||||
}
|
||||
);
|
||||
|
@ -681,7 +693,20 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
|||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress,),
|
||||
title: Some("Thinking".into()),
|
||||
kind: Some(acp::ToolKind::Think),
|
||||
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ pub struct EchoTool;
|
|||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
|
@ -48,6 +49,7 @@ pub struct DelayTool;
|
|||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
|
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
|
|||
|
||||
impl AgentTool for ToolRequiringPermission {
|
||||
type Input = ToolRequiringPermissionInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"tool_requiring_permission".into()
|
||||
|
@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission {
|
|||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let auth_check = self.authorize(input, event_stream);
|
||||
) -> Task<Result<String>> {
|
||||
let auth_check = event_stream.authorize("Authorize?".into());
|
||||
cx.foreground_executor().spawn(async move {
|
||||
auth_check.await?;
|
||||
Ok("Allowed".to_string())
|
||||
|
@ -121,6 +121,7 @@ pub struct InfiniteTool;
|
|||
|
||||
impl AgentTool for InfiniteTool {
|
||||
type Input = InfiniteToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"infinite".into()
|
||||
|
@ -171,19 +172,20 @@ pub struct WordListTool;
|
|||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||
use crate::{SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||
|
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
ToolCallDiff(ToolCallDiff),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallDiff {
|
||||
pub tool_call_id: acp::ToolCallId,
|
||||
pub diff: Entity<acp_thread::Diff>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
|
@ -125,12 +133,13 @@ pub struct Thread {
|
|||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
_project: Entity<Project>,
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
|
@ -145,10 +154,19 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project(&self) -> &Entity<Project> {
|
||||
&self.project
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
@ -315,10 +333,6 @@ impl Thread {
|
|||
events_rx
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> AgentMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
|
@ -460,8 +474,17 @@ impl Thread {
|
|||
}
|
||||
});
|
||||
|
||||
let mut title = SharedString::from(&tool_use.name);
|
||||
let mut kind = acp::ToolKind::Other;
|
||||
if let Some(tool) = tool.as_ref() {
|
||||
if let Ok(initial_title) = tool.initial_title(tool_use.input.clone()) {
|
||||
title = initial_title;
|
||||
}
|
||||
kind = tool.kind();
|
||||
}
|
||||
|
||||
if push_new_tool_use {
|
||||
event_stream.send_tool_call(tool.as_ref(), &tool_use);
|
||||
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
|
@ -469,6 +492,8 @@ impl Thread {
|
|||
event_stream.send_tool_call_update(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields {
|
||||
title: Some(title.into()),
|
||||
kind: Some(kind),
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
|
@ -490,15 +515,33 @@ impl Thread {
|
|||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let supports_images = self.selected_model.supports_images();
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
let tool_result = tool_result.await.and_then(|output| {
|
||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||
if !supports_images {
|
||||
return Err(anyhow!(
|
||||
"Attempted to read an image, but this model doesn't support it.",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
});
|
||||
|
||||
match tool_result {
|
||||
Ok(output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
content: output.llm_output,
|
||||
output: Some(output.raw_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
|
@ -511,24 +554,6 @@ impl Thread {
|
|||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
&self,
|
||||
tool: Arc<dyn AnyAgentTool>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -572,7 +597,7 @@ impl Thread {
|
|||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
pub(crate) fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
|
@ -662,6 +687,7 @@ where
|
|||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
|
@ -685,23 +711,13 @@ where
|
|||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
|
@ -710,6 +726,11 @@ where
|
|||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
|
@ -721,7 +742,7 @@ pub trait AnyAgentTool {
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<AgentToolOutput>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
|
@ -756,12 +777,18 @@ where
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let input = serde_json::from_value(input)?;
|
||||
let output = cx
|
||||
.update(|cx| self.0.clone().run(input, event_stream, cx))?
|
||||
.await?;
|
||||
let raw_output = serde_json::to_value(&output)?;
|
||||
Ok(AgentToolOutput {
|
||||
llm_output: output.into(),
|
||||
raw_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -826,17 +853,17 @@ impl AgentResponseEventStream {
|
|||
|
||||
fn send_tool_call(
|
||||
&self,
|
||||
tool: Option<&Arc<dyn AnyAgentTool>>,
|
||||
tool_use: &LanguageModelToolUse,
|
||||
id: &LanguageModelToolUseId,
|
||||
title: SharedString,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
||||
&tool_use.id,
|
||||
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
|
||||
.map(|i| i.into())
|
||||
.unwrap_or_else(|| tool_use.name.to_string()),
|
||||
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
|
||||
tool_use.input.clone(),
|
||||
id,
|
||||
title.to_string(),
|
||||
kind,
|
||||
input,
|
||||
))))
|
||||
.ok();
|
||||
}
|
||||
|
@ -874,6 +901,12 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
@ -903,13 +936,41 @@ impl AgentResponseEventStream {
|
|||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
#[cfg(test)]
|
||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
);
|
||||
|
||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||
}
|
||||
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
stream: AgentResponseEventStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
@ -918,38 +979,52 @@ impl ToolCallEventStream {
|
|||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
pub fn send_diff(&self, diff: Entity<Diff>) {
|
||||
self.stream.send_tool_call_diff(ToolCallDiff {
|
||||
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
diff,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream.authorize_tool_call(
|
||||
&self.tool_use_id,
|
||||
title,
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct TestToolCallEventStream {
|
||||
stream: ToolCallEventStream,
|
||||
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
}
|
||||
pub struct ToolCallEventStreamReceiver(
|
||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
impl TestToolCallEventStream {
|
||||
pub fn new() -> Self {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
|
||||
|
||||
Self {
|
||||
stream,
|
||||
_events_rx: events_rx,
|
||||
impl ToolCallEventStreamReceiver {
|
||||
pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
|
||||
auth
|
||||
} else {
|
||||
panic!("Expected ToolCallAuthorization but got: {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> ToolCallEventStream {
|
||||
self.stream.clone()
|
||||
#[cfg(test)]
|
||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
mod edit_file_tool;
|
||||
mod find_path_tool;
|
||||
mod read_file_tool;
|
||||
mod thinking_tool;
|
||||
|
||||
pub use edit_file_tool::*;
|
||||
pub use find_path_tool::*;
|
||||
pub use read_file_tool::*;
|
||||
pub use thinking_tool::*;
|
||||
|
|
1361
crates/agent2/src/tools/edit_file_tool.rs
Normal file
1361
crates/agent2/src/tools/edit_file_tool.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,8 @@
|
|||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -8,8 +10,6 @@ use std::fmt::Write;
|
|||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// Fast file path pattern matching tool that works with any codebase size
|
||||
///
|
||||
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
|
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FindPathToolOutput {
|
||||
paths: Vec<PathBuf>,
|
||||
pub struct FindPathToolOutput {
|
||||
offset: usize,
|
||||
current_matches_page: Vec<PathBuf>,
|
||||
all_matches_len: usize,
|
||||
}
|
||||
|
||||
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
|
||||
fn from(output: FindPathToolOutput) -> Self {
|
||||
if output.current_matches_page.is_empty() {
|
||||
"No matches found".into()
|
||||
} else {
|
||||
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
|
||||
if output.all_matches_len > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut llm_output,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
output.offset + 1,
|
||||
output.offset + output.current_matches_page.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in output.current_matches_page {
|
||||
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
llm_output.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
@ -57,6 +84,7 @@ impl FindPathTool {
|
|||
|
||||
impl AgentTool for FindPathTool {
|
||||
type Input = FindPathToolInput;
|
||||
type Output = FindPathToolOutput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"find_path".into()
|
||||
|
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
|
|||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<FindPathToolOutput>> {
|
||||
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
|
|||
..Default::default()
|
||||
});
|
||||
|
||||
if matches.is_empty() {
|
||||
Ok("No matches found".into())
|
||||
} else {
|
||||
let mut message = format!("Found {} total matches.", matches.len());
|
||||
if matches.len() > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut message,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
input.offset + 1,
|
||||
input.offset + paginated_matches.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
|
||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
Ok(FindPathToolOutput {
|
||||
offset: input.offset,
|
||||
current_matches_page: paginated_matches.to_vec(),
|
||||
all_matches_len: matches.len(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use assistant_tool::{outline, ActionLog};
|
||||
use gpui::{Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::{Anchor, Point};
|
||||
use project::{AgentLocation, Project, WorktreeSettings};
|
||||
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
|
||||
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
@ -59,6 +60,7 @@ impl ReadFileTool {
|
|||
|
||||
impl AgentTool for ReadFileTool {
|
||||
type Input = ReadFileToolInput;
|
||||
type Output = LanguageModelToolResultContent;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"read_file".into()
|
||||
|
@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<LanguageModelToolResultContent>> {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
|
||||
};
|
||||
|
@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
let file_path = input.path.clone();
|
||||
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: project_path.path.to_path_buf(),
|
||||
line: input.start_line,
|
||||
// TODO (tracked): use full range
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
return cx.spawn(async move |cx| {
|
||||
let image_entity: Entity<ImageItem> = cx
|
||||
.update(|cx| {
|
||||
self.project.update(cx, |project, cx| {
|
||||
project.open_image(project_path.clone(), cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
// TODO (tracked): images
|
||||
// if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
// let model = &self.thread.read(cx).selected_model;
|
||||
let image =
|
||||
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
// if !model.supports_images() {
|
||||
// return Task::ready(Err(anyhow!(
|
||||
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
|
||||
// model.name().0
|
||||
// )))
|
||||
// .into();
|
||||
// }
|
||||
let language_model_image = cx
|
||||
.update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
.await
|
||||
.context("processing image")?;
|
||||
|
||||
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
|
||||
// let image_entity: Entity<ImageItem> = cx
|
||||
// .update(|cx| {
|
||||
// self.project.update(cx, |project, cx| {
|
||||
// project.open_image(project_path.clone(), cx)
|
||||
// })
|
||||
// })?
|
||||
// .await?;
|
||||
|
||||
// let image =
|
||||
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
// let language_model_image = cx
|
||||
// .update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
// .await
|
||||
// .context("processing image")?;
|
||||
|
||||
// Ok(ToolResultOutput {
|
||||
// content: ToolResultContent::Image(language_model_image),
|
||||
// output: None,
|
||||
// })
|
||||
// });
|
||||
// }
|
||||
//
|
||||
Ok(language_model_image.into())
|
||||
});
|
||||
}
|
||||
|
||||
let project = self.project.clone();
|
||||
let action_log = self.action_log.clone();
|
||||
|
@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool {
|
|||
})?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool {
|
|||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
|
@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
Alternatively, you can fall back to the `grep` tool (if available)
|
||||
to search the file for specific content."
|
||||
})
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::TestToolCallEventStream;
|
||||
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
|
||||
|
@ -304,7 +281,7 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let (event_stream, _) = ToolCallEventStream::test();
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
|
@ -313,7 +290,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, event_stream, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
|
@ -321,6 +298,7 @@ mod test {
|
|||
"root/nonexistent_file.txt not found"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_small_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
@ -336,7 +314,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
|
@ -344,10 +321,10 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "This is a small file content");
|
||||
assert_eq!(result.unwrap(), "This is a small file content".into());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -367,18 +344,18 @@ mod test {
|
|||
language_registry.add(Arc::new(rust_lang()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let content = cx
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let content = result.to_str().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
content.lines().skip(4).take(6).collect::<Vec<_>>(),
|
||||
|
@ -399,10 +376,11 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
let content = result.unwrap();
|
||||
.await
|
||||
.unwrap();
|
||||
let content = result.to_str().unwrap();
|
||||
let expected_content = (0..1000)
|
||||
.flat_map(|i| {
|
||||
vec![
|
||||
|
@ -438,7 +416,6 @@ mod test {
|
|||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
|
@ -446,10 +423,10 @@ mod test {
|
|||
start_line: Some(2),
|
||||
end_line: Some(4),
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -467,7 +444,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// start_line of 0 should be treated as 1
|
||||
let result = cx
|
||||
|
@ -477,10 +453,10 @@ mod test {
|
|||
start_line: Some(0),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2");
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
|
||||
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
|
@ -490,10 +466,10 @@ mod test {
|
|||
start_line: Some(1),
|
||||
end_line: Some(0),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1");
|
||||
assert_eq!(result.unwrap(), "Line 1".into());
|
||||
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
|
@ -503,10 +479,10 @@ mod test {
|
|||
start_line: Some(3),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 3");
|
||||
assert_eq!(result.unwrap(), "Line 3".into());
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
@ -612,7 +588,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Reading a file outside the project worktree should fail
|
||||
let result = cx
|
||||
|
@ -622,7 +597,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -638,7 +613,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -654,7 +629,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -669,7 +644,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -685,7 +660,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -700,7 +675,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -715,7 +690,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -731,11 +706,11 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_ok(), "Should be able to read normal files");
|
||||
assert_eq!(result.unwrap(), "Normal file content");
|
||||
assert_eq!(result.unwrap(), "Normal file content".into());
|
||||
|
||||
// Path traversal attempts with .. should fail
|
||||
let result = cx
|
||||
|
@ -745,7 +720,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -826,7 +801,6 @@ mod test {
|
|||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Test reading allowed files in worktree1
|
||||
let result = cx
|
||||
|
@ -836,12 +810,15 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
|
||||
assert_eq!(
|
||||
result,
|
||||
"fn main() { println!(\"Hello from worktree1\"); }".into()
|
||||
);
|
||||
|
||||
// Test reading private file in worktree1 should fail
|
||||
let result = cx
|
||||
|
@ -851,7 +828,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -872,7 +849,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -893,14 +870,14 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"export function greet() { return 'Hello from worktree2'; }"
|
||||
"export function greet() { return 'Hello from worktree2'; }".into()
|
||||
);
|
||||
|
||||
// Test reading private file in worktree2 should fail
|
||||
|
@ -911,7 +888,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -932,7 +909,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -954,7 +931,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ pub struct ThinkingTool;
|
|||
|
||||
impl AgentTool for ThinkingTool {
|
||||
type Input = ThinkingToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"thinking".into()
|
||||
|
|
|
@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace};
|
|||
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||
|
||||
use ::acp_thread::{
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
|
||||
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
|
||||
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
|
||||
};
|
||||
|
||||
|
@ -732,7 +732,11 @@ impl AcpThreadView {
|
|||
cx: &App,
|
||||
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
|
||||
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
|
||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
Some(
|
||||
entry
|
||||
.diffs()
|
||||
.map(|diff| diff.read(cx).multibuffer().clone()),
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
|
@ -1314,10 +1318,9 @@ impl AcpThreadView {
|
|||
Empty.into_any_element()
|
||||
}
|
||||
}
|
||||
ToolCallContent::Diff {
|
||||
diff: Diff { multibuffer, .. },
|
||||
..
|
||||
} => self.render_diff_editor(multibuffer),
|
||||
ToolCallContent::Diff { diff, .. } => {
|
||||
self.render_diff_editor(&diff.read(cx).multibuffer())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ mod copy_path_tool;
|
|||
mod create_directory_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_agent;
|
||||
pub mod edit_agent;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
|
@ -14,7 +14,7 @@ mod open_tool;
|
|||
mod project_notifications_tool;
|
||||
mod read_file_tool;
|
||||
mod schema;
|
||||
mod templates;
|
||||
pub mod templates;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod ui;
|
||||
|
|
|
@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize};
|
|||
use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
|
||||
use streaming_diff::{CharOperation, StreamingDiff};
|
||||
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
|
||||
use util::debug_panic;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CreateFilePromptTemplate {
|
||||
|
@ -682,11 +681,6 @@ impl EditAgent {
|
|||
if last_message.content.is_empty() {
|
||||
conversation.messages.pop();
|
||||
}
|
||||
} else {
|
||||
debug_panic!(
|
||||
"Last message must be an Assistant tool calling! Got {:?}",
|
||||
last_message.content
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -120,8 +120,6 @@ struct PartialInput {
|
|||
display_description: String,
|
||||
}
|
||||
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
|
||||
impl Tool for EditFileTool {
|
||||
fn name(&self) -> String {
|
||||
"edit_file".into()
|
||||
|
@ -211,22 +209,6 @@ impl Tool for EditFileTool {
|
|||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string();
|
||||
}
|
||||
|
||||
let path = input.path.trim();
|
||||
if !path.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_UI_TEXT.to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
|
@ -1370,73 +1352,6 @@ mod tests {
|
|||
assert_eq!(actual, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path() {
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_description() {
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
"Fix error handling",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path_and_description() {
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
"Fix error handling",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_no_path_or_description() {
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
DEFAULT_UI_TEXT,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_null() {
|
||||
let input = serde_json::Value::Null;
|
||||
|
||||
assert_eq!(
|
||||
EditFileTool.still_streaming_ui_text(&input),
|
||||
DEFAULT_UI_TEXT,
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
|
|
|
@ -297,6 +297,12 @@ impl From<String> for LanguageModelToolResultContent {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for LanguageModelToolResultContent {
|
||||
fn from(image: LanguageModelImage) -> Self {
|
||||
Self::Image(image)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
|
||||
pub enum MessageContent {
|
||||
Text(String),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue