Introduce a AgentTool::Output
associated type
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
8f390d9c6d
commit
e06f54054a
6 changed files with 110 additions and 51 deletions
|
@ -14,6 +14,7 @@ pub struct EchoTool;
|
|||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
|
@ -48,6 +49,7 @@ pub struct DelayTool;
|
|||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
|
@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
|
|||
|
||||
impl AgentTool for ToolRequiringPermission {
|
||||
type Input = ToolRequiringPermissionInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"tool_requiring_permission".into()
|
||||
|
@ -118,6 +121,7 @@ pub struct InfiniteTool;
|
|||
|
||||
impl AgentTool for InfiniteTool {
|
||||
type Input = InfiniteToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"infinite".into()
|
||||
|
@ -168,6 +172,7 @@ pub struct WordListTool;
|
|||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
|
|
|
@ -517,8 +517,8 @@ impl Thread {
|
|||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
content: tool_output.llm_output,
|
||||
output: Some(tool_output.raw_output),
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
|
@ -664,6 +664,7 @@ where
|
|||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
|
@ -693,7 +694,7 @@ where
|
|||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
|
@ -702,6 +703,11 @@ where
|
|||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub struct AgentToolOutput {
|
||||
llm_output: LanguageModelToolResultContent,
|
||||
raw_output: serde_json::Value,
|
||||
}
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
|
@ -713,7 +719,7 @@ pub trait AnyAgentTool {
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
) -> Task<Result<AgentToolOutput>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
|
@ -748,12 +754,18 @@ where
|
|||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let input = serde_json::from_value(input)?;
|
||||
let output = cx
|
||||
.update(|cx| self.0.clone().run(input, event_stream, cx))?
|
||||
.await?;
|
||||
let raw_output = serde_json::to_value(&output)?;
|
||||
Ok(AgentToolOutput {
|
||||
llm_output: output.into(),
|
||||
raw_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutputEvent, EditFormat};
|
||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use paths;
|
||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||
use project::{Project, ProjectPath};
|
||||
|
@ -85,6 +86,31 @@ pub enum EditFileMode {
|
|||
Overwrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EditFileToolOutput {
|
||||
input_path: PathBuf,
|
||||
project_path: PathBuf,
|
||||
new_text: String,
|
||||
old_text: Arc<String>,
|
||||
diff: String,
|
||||
edit_agent_output: EditAgentOutput,
|
||||
}
|
||||
|
||||
impl From<EditFileToolOutput> for LanguageModelToolResultContent {
|
||||
fn from(output: EditFileToolOutput) -> Self {
|
||||
if output.diff.is_empty() {
|
||||
"No edits were made.".into()
|
||||
} else {
|
||||
format!(
|
||||
"Edited {}:\n\n```diff\n{}\n```",
|
||||
output.input_path.display(),
|
||||
output.diff
|
||||
)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EditFileTool {
|
||||
thread: Entity<Thread>,
|
||||
}
|
||||
|
@ -146,6 +172,7 @@ impl EditFileTool {
|
|||
|
||||
impl AgentTool for EditFileTool {
|
||||
type Input = EditFileToolInput;
|
||||
type Output = EditFileToolOutput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"edit_file".into()
|
||||
|
@ -164,7 +191,7 @@ impl AgentTool for EditFileTool {
|
|||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let project = self.thread.read(cx).project().clone();
|
||||
let project_path = match resolve_path(&input, project.clone(), cx) {
|
||||
Ok(path) => path,
|
||||
|
@ -259,7 +286,7 @@ impl AgentTool for EditFileTool {
|
|||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let _ = output.await?;
|
||||
let edit_agent_output = output.await?;
|
||||
|
||||
if format_on_save_enabled {
|
||||
action_log.update(cx, |log, cx| {
|
||||
|
@ -287,22 +314,19 @@ impl AgentTool for EditFileTool {
|
|||
})?;
|
||||
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let unified_diff = cx
|
||||
let (new_text, unified_diff) = cx
|
||||
.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
let old_text = old_text.clone();
|
||||
async move {
|
||||
let new_text = new_snapshot.text();
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
let diff = language::unified_diff(&old_text, &new_text);
|
||||
(new_text, diff)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
println!("\n\n{}\n\n", unified_diff);
|
||||
|
||||
diff.update(cx, |diff, cx| {
|
||||
diff.finalize(cx);
|
||||
}).ok();
|
||||
diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
|
||||
|
||||
let input_path = input.path.display();
|
||||
if unified_diff.is_empty() {
|
||||
|
@ -329,13 +353,16 @@ impl AgentTool for EditFileTool {
|
|||
"}
|
||||
}
|
||||
);
|
||||
Ok("No edits were made.".into())
|
||||
} else {
|
||||
Ok(format!(
|
||||
"Edited {}:\n\n```diff\n{}\n```",
|
||||
input_path, unified_diff
|
||||
))
|
||||
}
|
||||
|
||||
Ok(EditFileToolOutput {
|
||||
input_path: input.path,
|
||||
project_path: project_path.path.to_path_buf(),
|
||||
new_text: new_text.clone(),
|
||||
old_text,
|
||||
diff: unified_diff,
|
||||
edit_agent_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -8,8 +10,6 @@ use std::fmt::Write;
|
|||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// Fast file path pattern matching tool that works with any codebase size
|
||||
///
|
||||
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
|
@ -39,8 +39,35 @@ pub struct FindPathToolInput {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FindPathToolOutput {
|
||||
paths: Vec<PathBuf>,
|
||||
pub struct FindPathToolOutput {
|
||||
offset: usize,
|
||||
current_matches_page: Vec<PathBuf>,
|
||||
all_matches_len: usize,
|
||||
}
|
||||
|
||||
impl From<FindPathToolOutput> for LanguageModelToolResultContent {
|
||||
fn from(output: FindPathToolOutput) -> Self {
|
||||
if output.current_matches_page.is_empty() {
|
||||
"No matches found".into()
|
||||
} else {
|
||||
let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
|
||||
if output.all_matches_len > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut llm_output,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
output.offset + 1,
|
||||
output.offset + output.current_matches_page.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in output.current_matches_page {
|
||||
write!(&mut llm_output, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
llm_output.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
@ -57,6 +84,7 @@ impl FindPathTool {
|
|||
|
||||
impl AgentTool for FindPathTool {
|
||||
type Input = FindPathToolInput;
|
||||
type Output = FindPathToolOutput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"find_path".into()
|
||||
|
@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
|
|||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
) -> Task<Result<FindPathToolOutput>> {
|
||||
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
|
@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
|
|||
..Default::default()
|
||||
});
|
||||
|
||||
if matches.is_empty() {
|
||||
Ok("No matches found".into())
|
||||
} else {
|
||||
let mut message = format!("Found {} total matches.", matches.len());
|
||||
if matches.len() > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut message,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
input.offset + 1,
|
||||
input.offset + paginated_matches.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
|
||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
Ok(FindPathToolOutput {
|
||||
offset: input.offset,
|
||||
current_matches_page: paginated_matches.to_vec(),
|
||||
all_matches_len: matches.len(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,6 +59,7 @@ impl ReadFileTool {
|
|||
|
||||
impl AgentTool for ReadFileTool {
|
||||
type Input = ReadFileToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"read_file".into()
|
||||
|
|
|
@ -20,6 +20,7 @@ pub struct ThinkingTool;
|
|||
|
||||
impl AgentTool for ThinkingTool {
|
||||
type Input = ThinkingToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"thinking".into()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue