Introduce a AgentTool::Output associated type

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 12:24:11 +02:00
parent 8f390d9c6d
commit e06f54054a
6 changed files with 110 additions and 51 deletions

View file

@ -14,6 +14,7 @@ pub struct EchoTool;
impl AgentTool for EchoTool { impl AgentTool for EchoTool {
type Input = EchoToolInput; type Input = EchoToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"echo".into() "echo".into()
@ -48,6 +49,7 @@ pub struct DelayTool;
impl AgentTool for DelayTool { impl AgentTool for DelayTool {
type Input = DelayToolInput; type Input = DelayToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"delay".into() "delay".into()
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission { impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput; type Input = ToolRequiringPermissionInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"tool_requiring_permission".into() "tool_requiring_permission".into()
@ -118,6 +121,7 @@ pub struct InfiniteTool;
impl AgentTool for InfiniteTool { impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput; type Input = InfiniteToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"infinite".into() "infinite".into()
@ -168,6 +172,7 @@ pub struct WordListTool;
impl AgentTool for WordListTool { impl AgentTool for WordListTool {
type Input = WordListInput; type Input = WordListInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"word_list".into() "word_list".into()

View file

@ -517,8 +517,8 @@ impl Thread {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,
tool_name: tool_use.name, tool_name: tool_use.name,
is_error: false, is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), content: tool_output.llm_output,
output: None, output: Some(tool_output.raw_output),
}, },
Err(error) => LanguageModelToolResult { Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id, tool_use_id: tool_use.id,
@ -664,6 +664,7 @@ where
Self: 'static + Sized, Self: 'static + Sized,
{ {
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString; fn name(&self) -> SharedString;
@ -693,7 +694,7 @@ where
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> { fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self))) Arc::new(Erased(Arc::new(self)))
@ -702,6 +703,11 @@ where
pub struct Erased<T>(T); pub struct Erased<T>(T);
pub struct AgentToolOutput {
llm_output: LanguageModelToolResultContent,
raw_output: serde_json::Value,
}
pub trait AnyAgentTool { pub trait AnyAgentTool {
fn name(&self) -> SharedString; fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString; fn description(&self, cx: &mut App) -> SharedString;
@ -713,7 +719,7 @@ pub trait AnyAgentTool {
input: serde_json::Value, input: serde_json::Value,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>>; ) -> Task<Result<AgentToolOutput>>;
} }
impl<T> AnyAgentTool for Erased<Arc<T>> impl<T> AnyAgentTool for Erased<Arc<T>>
@ -748,12 +754,18 @@ where
input: serde_json::Value, input: serde_json::Value,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<AgentToolOutput>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into); cx.spawn(async move |cx| {
match parsed_input { let input = serde_json::from_value(input)?;
Ok(input) => self.0.clone().run(input, event_stream, cx), let output = cx
Err(error) => Task::ready(Err(anyhow!(error))), .update(|cx| self.0.clone().run(input, event_stream, cx))?
} .await?;
let raw_output = serde_json::to_value(&output)?;
Ok(AgentToolOutput {
llm_output: output.into(),
raw_output,
})
})
} }
} }

View file

@ -1,12 +1,13 @@
use acp_thread::Diff; use acp_thread::Diff;
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutputEvent, EditFormat}; use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent; use cloud_llm_client::CompletionIntent;
use collections::HashSet; use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task}; use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc; use indoc::formatdoc;
use language::language_settings::{self, FormatOnSave}; use language::language_settings::{self, FormatOnSave};
use language_model::LanguageModelToolResultContent;
use paths; use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget}; use project::lsp_store::{FormatTrigger, LspFormatTarget};
use project::{Project, ProjectPath}; use project::{Project, ProjectPath};
@ -85,6 +86,31 @@ pub enum EditFileMode {
Overwrite, Overwrite,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
diff: String,
edit_agent_output: EditAgentOutput,
}
impl From<EditFileToolOutput> for LanguageModelToolResultContent {
fn from(output: EditFileToolOutput) -> Self {
if output.diff.is_empty() {
"No edits were made.".into()
} else {
format!(
"Edited {}:\n\n```diff\n{}\n```",
output.input_path.display(),
output.diff
)
.into()
}
}
}
pub struct EditFileTool { pub struct EditFileTool {
thread: Entity<Thread>, thread: Entity<Thread>,
} }
@ -146,6 +172,7 @@ impl EditFileTool {
impl AgentTool for EditFileTool { impl AgentTool for EditFileTool {
type Input = EditFileToolInput; type Input = EditFileToolInput;
type Output = EditFileToolOutput;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"edit_file".into() "edit_file".into()
@ -164,7 +191,7 @@ impl AgentTool for EditFileTool {
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone(); let project = self.thread.read(cx).project().clone();
let project_path = match resolve_path(&input, project.clone(), cx) { let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path, Ok(path) => path,
@ -259,7 +286,7 @@ impl AgentTool for EditFileTool {
}) })
.unwrap_or(false); .unwrap_or(false);
let _ = output.await?; let edit_agent_output = output.await?;
if format_on_save_enabled { if format_on_save_enabled {
action_log.update(cx, |log, cx| { action_log.update(cx, |log, cx| {
@ -287,22 +314,19 @@ impl AgentTool for EditFileTool {
})?; })?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let unified_diff = cx let (new_text, unified_diff) = cx
.background_spawn({ .background_spawn({
let new_snapshot = new_snapshot.clone(); let new_snapshot = new_snapshot.clone();
let old_text = old_text.clone(); let old_text = old_text.clone();
async move { async move {
let new_text = new_snapshot.text(); let new_text = new_snapshot.text();
language::unified_diff(&old_text, &new_text) let diff = language::unified_diff(&old_text, &new_text);
(new_text, diff)
} }
}) })
.await; .await;
println!("\n\n{}\n\n", unified_diff); diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
diff.update(cx, |diff, cx| {
diff.finalize(cx);
}).ok();
let input_path = input.path.display(); let input_path = input.path.display();
if unified_diff.is_empty() { if unified_diff.is_empty() {
@ -329,13 +353,16 @@ impl AgentTool for EditFileTool {
"} "}
} }
); );
Ok("No edits were made.".into())
} else {
Ok(format!(
"Edited {}:\n\n```diff\n{}\n```",
input_path, unified_diff
))
} }
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
edit_agent_output,
})
}) })
} }
} }

View file

@ -1,6 +1,8 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task}; use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -8,8 +10,6 @@ use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc}; use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher; use util::paths::PathMatcher;
use crate::{AgentTool, ToolCallEventStream};
/// Fast file path pattern matching tool that works with any codebase size /// Fast file path pattern matching tool that works with any codebase size
/// ///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts" /// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput { pub struct FindPathToolOutput {
paths: Vec<PathBuf>, offset: usize,
current_matches_page: Vec<PathBuf>,
all_matches_len: usize,
}
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
fn from(output: FindPathToolOutput) -> Self {
if output.current_matches_page.is_empty() {
"No matches found".into()
} else {
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
if output.all_matches_len > RESULTS_PER_PAGE {
write!(
&mut llm_output,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
output.offset + 1,
output.offset + output.current_matches_page.len()
)
.unwrap();
}
for mat in output.current_matches_page {
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
}
llm_output.into()
}
}
} }
const RESULTS_PER_PAGE: usize = 50; const RESULTS_PER_PAGE: usize = 50;
@ -57,6 +84,7 @@ impl FindPathTool {
impl AgentTool for FindPathTool { impl AgentTool for FindPathTool {
type Input = FindPathToolInput; type Input = FindPathToolInput;
type Output = FindPathToolOutput;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"find_path".into() "find_path".into()
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
input: Self::Input, input: Self::Input,
event_stream: ToolCallEventStream, event_stream: ToolCallEventStream,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move { cx.background_spawn(async move {
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
..Default::default() ..Default::default()
}); });
if matches.is_empty() { Ok(FindPathToolOutput {
Ok("No matches found".into()) offset: input.offset,
} else { current_matches_page: paginated_matches.to_vec(),
let mut message = format!("Found {} total matches.", matches.len()); all_matches_len: matches.len(),
if matches.len() > RESULTS_PER_PAGE { })
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
input.offset + 1,
input.offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
}) })
} }
} }

View file

@ -59,6 +59,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool { impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput; type Input = ReadFileToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"read_file".into() "read_file".into()

View file

@ -20,6 +20,7 @@ pub struct ThinkingTool;
impl AgentTool for ThinkingTool { impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput; type Input = ThinkingToolInput;
type Output = String;
fn name(&self) -> SharedString { fn name(&self) -> SharedString {
"thinking".into() "thinking".into()