Introduce a AgentTool::Output associated type

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-08 12:24:11 +02:00
parent 8f390d9c6d
commit e06f54054a
6 changed files with 110 additions and 51 deletions

View file

@ -14,6 +14,7 @@ pub struct EchoTool;
impl AgentTool for EchoTool {
type Input = EchoToolInput;
type Output = String;
fn name(&self) -> SharedString {
"echo".into()
@ -48,6 +49,7 @@ pub struct DelayTool;
impl AgentTool for DelayTool {
type Input = DelayToolInput;
type Output = String;
fn name(&self) -> SharedString {
"delay".into()
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
type Output = String;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
@ -118,6 +121,7 @@ pub struct InfiniteTool;
impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput;
type Output = String;
fn name(&self) -> SharedString {
"infinite".into()
@ -168,6 +172,7 @@ pub struct WordListTool;
impl AgentTool for WordListTool {
type Input = WordListInput;
type Output = String;
fn name(&self) -> SharedString {
"word_list".into()

View file

@ -517,8 +517,8 @@ impl Thread {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
output: None,
content: tool_output.llm_output,
output: Some(tool_output.raw_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
@ -664,6 +664,7 @@ where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString;
@ -693,7 +694,7 @@ where
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
@ -702,6 +703,11 @@ where
pub struct Erased<T>(T);
pub struct AgentToolOutput {
llm_output: LanguageModelToolResultContent,
raw_output: serde_json::Value,
}
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
@ -713,7 +719,7 @@ pub trait AnyAgentTool {
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>>;
) -> Task<Result<AgentToolOutput>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@ -748,12 +754,18 @@ where
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => self.0.clone().run(input, event_stream, cx),
Err(error) => Task::ready(Err(anyhow!(error))),
}
) -> Task<Result<AgentToolOutput>> {
cx.spawn(async move |cx| {
let input = serde_json::from_value(input)?;
let output = cx
.update(|cx| self.0.clone().run(input, event_stream, cx))?
.await?;
let raw_output = serde_json::to_value(&output)?;
Ok(AgentToolOutput {
llm_output: output.into(),
raw_output,
})
})
}
}

View file

@ -1,12 +1,13 @@
use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutputEvent, EditFormat};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language::language_settings::{self, FormatOnSave};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
use project::{Project, ProjectPath};
@ -85,6 +86,31 @@ pub enum EditFileMode {
Overwrite,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
input_path: PathBuf,
project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
diff: String,
edit_agent_output: EditAgentOutput,
}
impl From<EditFileToolOutput> for LanguageModelToolResultContent {
fn from(output: EditFileToolOutput) -> Self {
if output.diff.is_empty() {
"No edits were made.".into()
} else {
format!(
"Edited {}:\n\n```diff\n{}\n```",
output.input_path.display(),
output.diff
)
.into()
}
}
}
pub struct EditFileTool {
thread: Entity<Thread>,
}
@ -146,6 +172,7 @@ impl EditFileTool {
impl AgentTool for EditFileTool {
type Input = EditFileToolInput;
type Output = EditFileToolOutput;
fn name(&self) -> SharedString {
"edit_file".into()
@ -164,7 +191,7 @@ impl AgentTool for EditFileTool {
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
) -> Task<Result<Self::Output>> {
let project = self.thread.read(cx).project().clone();
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
@ -259,7 +286,7 @@ impl AgentTool for EditFileTool {
})
.unwrap_or(false);
let _ = output.await?;
let edit_agent_output = output.await?;
if format_on_save_enabled {
action_log.update(cx, |log, cx| {
@ -287,22 +314,19 @@ impl AgentTool for EditFileTool {
})?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let unified_diff = cx
let (new_text, unified_diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
let old_text = old_text.clone();
async move {
let new_text = new_snapshot.text();
language::unified_diff(&old_text, &new_text)
let diff = language::unified_diff(&old_text, &new_text);
(new_text, diff)
}
})
.await;
println!("\n\n{}\n\n", unified_diff);
diff.update(cx, |diff, cx| {
diff.finalize(cx);
}).ok();
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
let input_path = input.path.display();
if unified_diff.is_empty() {
@ -329,13 +353,16 @@ impl AgentTool for EditFileTool {
"}
}
);
Ok("No edits were made.".into())
} else {
Ok(format!(
"Edited {}:\n\n```diff\n{}\n```",
input_path, unified_diff
))
}
Ok(EditFileToolOutput {
input_path: input.path,
project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
edit_agent_output,
})
})
}
}

View file

@ -1,6 +1,8 @@
use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
use language_model::LanguageModelToolResultContent;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -8,8 +10,6 @@ use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
use crate::{AgentTool, ToolCallEventStream};
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
}
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
paths: Vec<PathBuf>,
pub struct FindPathToolOutput {
offset: usize,
current_matches_page: Vec<PathBuf>,
all_matches_len: usize,
}
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
fn from(output: FindPathToolOutput) -> Self {
if output.current_matches_page.is_empty() {
"No matches found".into()
} else {
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
if output.all_matches_len > RESULTS_PER_PAGE {
write!(
&mut llm_output,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
output.offset + 1,
output.offset + output.current_matches_page.len()
)
.unwrap();
}
for mat in output.current_matches_page {
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
}
llm_output.into()
}
}
}
const RESULTS_PER_PAGE: usize = 50;
@ -57,6 +84,7 @@ impl FindPathTool {
impl AgentTool for FindPathTool {
type Input = FindPathToolInput;
type Output = FindPathToolOutput;
fn name(&self) -> SharedString {
"find_path".into()
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
..Default::default()
});
if matches.is_empty() {
Ok("No matches found".into())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
input.offset + 1,
input.offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
Ok(FindPathToolOutput {
offset: input.offset,
current_matches_page: paginated_matches.to_vec(),
all_matches_len: matches.len(),
})
})
}
}

View file

@ -59,6 +59,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput;
type Output = String;
fn name(&self) -> SharedString {
"read_file".into()

View file

@ -20,6 +20,7 @@ pub struct ThinkingTool;
impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput;
type Output = String;
fn name(&self) -> SharedString {
"thinking".into()