agent2: Port edit_file
tool (#35844)
TODO: - [x] Authorization - [x] Restore tests Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
d705585a2e
commit
2526dcb5a5
20 changed files with 2075 additions and 414 deletions
|
@ -15,8 +15,10 @@ workspace = true
|
|||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
|
@ -29,6 +31,7 @@ language.workspace = true
|
|||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
rust-embed.workspace = true
|
||||
|
@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] }
|
|||
gpui_tokio.workspace = true
|
||||
language = { workspace = true, "features" = ["test-support"] }
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
lsp = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
|
||||
use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
|
||||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
|
@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
thread
|
||||
});
|
||||
|
||||
|
@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
|||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.set_tool_call_diff(
|
||||
&tool_call_diff.tool_call_id,
|
||||
tool_call_diff.diff,
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
|
|
|
@ -9,5 +9,6 @@ mod tests;
|
|||
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
|
@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
|||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: None
|
||||
output: Some("Allowed".into())
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
|
|
|
@ -14,6 +14,7 @@ pub struct EchoTool;
|
|||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
|
@ -48,6 +49,7 @@ pub struct DelayTool;
|
|||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
|
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
|
|||
|
||||
impl AgentTool for ToolRequiringPermission {
|
||||
type Input = ToolRequiringPermissionInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"tool_requiring_permission".into()
|
||||
|
@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission {
|
|||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let auth_check = self.authorize(input, event_stream);
|
||||
) -> Task<Result<String>> {
|
||||
let auth_check = event_stream.authorize("Authorize?".into());
|
||||
cx.foreground_executor().spawn(async move {
|
||||
auth_check.await?;
|
||||
Ok("Allowed".to_string())
|
||||
|
@ -121,6 +121,7 @@ pub struct InfiniteTool;
|
|||
|
||||
impl AgentTool for InfiniteTool {
|
||||
type Input = InfiniteToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"infinite".into()
|
||||
|
@ -171,19 +172,20 @@ pub struct WordListTool;
|
|||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||
use crate::{SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||
|
@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
|
|||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
ToolCallDiff(ToolCallDiff),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
|
@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
|
|||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallDiff {
|
||||
pub tool_call_id: acp::ToolCallId,
|
||||
pub diff: Entity<acp_thread::Diff>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
|
@ -125,12 +133,13 @@ pub struct Thread {
|
|||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
_project: Entity<Project>,
|
||||
project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
|
@ -145,10 +154,19 @@ impl Thread {
|
|||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project(&self) -> &Entity<Project> {
|
||||
&self.project
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
@ -315,10 +333,6 @@ impl Thread {
|
|||
events_rx
|
||||
}
|
||||
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> AgentMessage {
|
||||
log::debug!("Building system message");
|
||||
let prompt = SystemPromptTemplate {
|
||||
|
@ -490,15 +504,33 @@ impl Thread {
|
|||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let supports_images = self.selected_model.supports_images();
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
let tool_result = tool_result.await.and_then(|output| {
|
||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||
if !supports_images {
|
||||
return Err(anyhow!(
|
||||
"Attempted to read an image, but this model doesn't support it.",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(output)
|
||||
});
|
||||
|
||||
match tool_result {
|
||||
Ok(output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
content: output.llm_output,
|
||||
output: Some(output.raw_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
|
@ -511,24 +543,6 @@ impl Thread {
|
|||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
&self,
|
||||
tool: Arc<dyn AnyAgentTool>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
|
@ -572,7 +586,7 @@ impl Thread {
|
|||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
pub(crate) fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
|
@ -662,6 +676,7 @@ where
|
|||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
|
@ -685,23 +700,13 @@ where
|
|||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
|
@ -710,6 +715,11 @@ where
|
|||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
|
@ -721,7 +731,7 @@ pub trait AnyAgentTool {
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<AgentToolOutput>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
|
@ -756,12 +766,18 @@ where
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let input = serde_json::from_value(input)?;
|
||||
let output = cx
|
||||
.update(|cx| self.0.clone().run(input, event_stream, cx))?
|
||||
.await?;
|
||||
let raw_output = serde_json::to_value(&output)?;
|
||||
Ok(AgentToolOutput {
|
||||
llm_output: output.into(),
|
||||
raw_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -874,6 +890,12 @@ impl AgentResponseEventStream {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
|
@ -903,13 +925,41 @@ impl AgentResponseEventStream {
|
|||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
#[cfg(test)]
|
||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
);
|
||||
|
||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||
}
|
||||
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
stream: AgentResponseEventStream,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
|
@ -918,38 +968,52 @@ impl ToolCallEventStream {
|
|||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
pub fn send_diff(&self, diff: Entity<Diff>) {
|
||||
self.stream.send_tool_call_diff(ToolCallDiff {
|
||||
tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
diff,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream.authorize_tool_call(
|
||||
&self.tool_use_id,
|
||||
title,
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct TestToolCallEventStream {
|
||||
stream: ToolCallEventStream,
|
||||
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
}
|
||||
pub struct ToolCallEventStreamReceiver(
|
||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
impl TestToolCallEventStream {
|
||||
pub fn new() -> Self {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
|
||||
|
||||
Self {
|
||||
stream,
|
||||
_events_rx: events_rx,
|
||||
impl ToolCallEventStreamReceiver {
|
||||
pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
|
||||
let event = self.0.next().await;
|
||||
if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
|
||||
auth
|
||||
} else {
|
||||
panic!("Expected ToolCallAuthorization but got: {:?}", event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> ToolCallEventStream {
|
||||
self.stream.clone()
|
||||
#[cfg(test)]
|
||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
mod edit_file_tool;
|
||||
mod find_path_tool;
|
||||
mod read_file_tool;
|
||||
mod thinking_tool;
|
||||
|
||||
pub use edit_file_tool::*;
|
||||
pub use find_path_tool::*;
|
||||
pub use read_file_tool::*;
|
||||
pub use thinking_tool::*;
|
||||
|
|
1361
crates/agent2/src/tools/edit_file_tool.rs
Normal file
1361
crates/agent2/src/tools/edit_file_tool.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,8 @@
|
|||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -8,8 +10,6 @@ use std::fmt::Write;
|
|||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// Fast file path pattern matching tool that works with any codebase size
|
||||
///
|
||||
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
|
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FindPathToolOutput {
|
||||
paths: Vec<PathBuf>,
|
||||
pub struct FindPathToolOutput {
|
||||
offset: usize,
|
||||
current_matches_page: Vec<PathBuf>,
|
||||
all_matches_len: usize,
|
||||
}
|
||||
|
||||
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
|
||||
fn from(output: FindPathToolOutput) -> Self {
|
||||
if output.current_matches_page.is_empty() {
|
||||
"No matches found".into()
|
||||
} else {
|
||||
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
|
||||
if output.all_matches_len > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut llm_output,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
output.offset + 1,
|
||||
output.offset + output.current_matches_page.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in output.current_matches_page {
|
||||
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
llm_output.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
@ -57,6 +84,7 @@ impl FindPathTool {
|
|||
|
||||
impl AgentTool for FindPathTool {
|
||||
type Input = FindPathToolInput;
|
||||
type Output = FindPathToolOutput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"find_path".into()
|
||||
|
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
|
|||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<FindPathToolOutput>> {
|
||||
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
|
|||
..Default::default()
|
||||
});
|
||||
|
||||
if matches.is_empty() {
|
||||
Ok("No matches found".into())
|
||||
} else {
|
||||
let mut message = format!("Found {} total matches.", matches.len());
|
||||
if matches.len() > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut message,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
input.offset + 1,
|
||||
input.offset + paginated_matches.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
|
||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
Ok(FindPathToolOutput {
|
||||
offset: input.offset,
|
||||
current_matches_page: paginated_matches.to_vec(),
|
||||
all_matches_len: matches.len(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use assistant_tool::{outline, ActionLog};
|
||||
use gpui::{Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::{Anchor, Point};
|
||||
use project::{AgentLocation, Project, WorktreeSettings};
|
||||
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
|
||||
use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
@ -59,6 +60,7 @@ impl ReadFileTool {
|
|||
|
||||
impl AgentTool for ReadFileTool {
|
||||
type Input = ReadFileToolInput;
|
||||
type Output = LanguageModelToolResultContent;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"read_file".into()
|
||||
|
@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool {
|
|||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<LanguageModelToolResultContent>> {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
|
||||
};
|
||||
|
@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
let file_path = input.path.clone();
|
||||
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: project_path.path.to_path_buf(),
|
||||
line: input.start_line,
|
||||
// TODO (tracked): use full range
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
return cx.spawn(async move |cx| {
|
||||
let image_entity: Entity<ImageItem> = cx
|
||||
.update(|cx| {
|
||||
self.project.update(cx, |project, cx| {
|
||||
project.open_image(project_path.clone(), cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
// TODO (tracked): images
|
||||
// if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
// let model = &self.thread.read(cx).selected_model;
|
||||
let image =
|
||||
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
// if !model.supports_images() {
|
||||
// return Task::ready(Err(anyhow!(
|
||||
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
|
||||
// model.name().0
|
||||
// )))
|
||||
// .into();
|
||||
// }
|
||||
let language_model_image = cx
|
||||
.update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
.await
|
||||
.context("processing image")?;
|
||||
|
||||
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
|
||||
// let image_entity: Entity<ImageItem> = cx
|
||||
// .update(|cx| {
|
||||
// self.project.update(cx, |project, cx| {
|
||||
// project.open_image(project_path.clone(), cx)
|
||||
// })
|
||||
// })?
|
||||
// .await?;
|
||||
|
||||
// let image =
|
||||
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
// let language_model_image = cx
|
||||
// .update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
// .await
|
||||
// .context("processing image")?;
|
||||
|
||||
// Ok(ToolResultOutput {
|
||||
// content: ToolResultContent::Image(language_model_image),
|
||||
// output: None,
|
||||
// })
|
||||
// });
|
||||
// }
|
||||
//
|
||||
Ok(language_model_image.into())
|
||||
});
|
||||
}
|
||||
|
||||
let project = self.project.clone();
|
||||
let action_log = self.action_log.clone();
|
||||
|
@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool {
|
|||
})?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool {
|
|||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
|
@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
Alternatively, you can fall back to the `grep` tool (if available)
|
||||
to search the file for specific content."
|
||||
})
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::TestToolCallEventStream;
|
||||
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
|
||||
|
@ -304,7 +281,7 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let (event_stream, _) = ToolCallEventStream::test();
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
|
@ -313,7 +290,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, event_stream, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
|
@ -321,6 +298,7 @@ mod test {
|
|||
"root/nonexistent_file.txt not found"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_small_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
@ -336,7 +314,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
|
@ -344,10 +321,10 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "This is a small file content");
|
||||
assert_eq!(result.unwrap(), "This is a small file content".into());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -367,18 +344,18 @@ mod test {
|
|||
language_registry.add(Arc::new(rust_lang()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let content = cx
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let content = result.to_str().unwrap();
|
||||
|
||||
assert_eq!(
|
||||
content.lines().skip(4).take(6).collect::<Vec<_>>(),
|
||||
|
@ -399,10 +376,11 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
let content = result.unwrap();
|
||||
.await
|
||||
.unwrap();
|
||||
let content = result.to_str().unwrap();
|
||||
let expected_content = (0..1000)
|
||||
.flat_map(|i| {
|
||||
vec![
|
||||
|
@ -438,7 +416,6 @@ mod test {
|
|||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
|
@ -446,10 +423,10 @@ mod test {
|
|||
start_line: Some(2),
|
||||
end_line: Some(4),
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -467,7 +444,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// start_line of 0 should be treated as 1
|
||||
let result = cx
|
||||
|
@ -477,10 +453,10 @@ mod test {
|
|||
start_line: Some(0),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2");
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
|
||||
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
|
@ -490,10 +466,10 @@ mod test {
|
|||
start_line: Some(1),
|
||||
end_line: Some(0),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1");
|
||||
assert_eq!(result.unwrap(), "Line 1".into());
|
||||
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
|
@ -503,10 +479,10 @@ mod test {
|
|||
start_line: Some(3),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 3");
|
||||
assert_eq!(result.unwrap(), "Line 3".into());
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
@ -612,7 +588,6 @@ mod test {
|
|||
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Reading a file outside the project worktree should fail
|
||||
let result = cx
|
||||
|
@ -622,7 +597,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -638,7 +613,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -654,7 +629,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -669,7 +644,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -685,7 +660,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -700,7 +675,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -715,7 +690,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -731,11 +706,11 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_ok(), "Should be able to read normal files");
|
||||
assert_eq!(result.unwrap(), "Normal file content");
|
||||
assert_eq!(result.unwrap(), "Normal file content".into());
|
||||
|
||||
// Path traversal attempts with .. should fail
|
||||
let result = cx
|
||||
|
@ -745,7 +720,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
tool.run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
|
@ -826,7 +801,6 @@ mod test {
|
|||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Test reading allowed files in worktree1
|
||||
let result = cx
|
||||
|
@ -836,12 +810,15 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
|
||||
assert_eq!(
|
||||
result,
|
||||
"fn main() { println!(\"Hello from worktree1\"); }".into()
|
||||
);
|
||||
|
||||
// Test reading private file in worktree1 should fail
|
||||
let result = cx
|
||||
|
@ -851,7 +828,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -872,7 +849,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -893,14 +870,14 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"export function greet() { return 'Hello from worktree2'; }"
|
||||
"export function greet() { return 'Hello from worktree2'; }".into()
|
||||
);
|
||||
|
||||
// Test reading private file in worktree2 should fail
|
||||
|
@ -911,7 +888,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -932,7 +909,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
@ -954,7 +931,7 @@ mod test {
|
|||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
tool.clone().run(input, ToolCallEventStream::test().0, cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ pub struct ThinkingTool;
|
|||
|
||||
impl AgentTool for ThinkingTool {
|
||||
type Input = ThinkingToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"thinking".into()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue