Add Batch tool call for calling multiple tools (#27621)

<img width="620" alt="Screenshot 2025-03-27 at 2 29 13 PM"
src="https://github.com/user-attachments/assets/dd023507-61bc-4722-a095-f65f4b6c746a"
/>

We'll iterate on the UI, but first the goal is to just get it to work at
all so we can see if it's useful in terms of getting correct output
faster.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
This commit is contained in:
Richard Feldman 2025-03-27 18:21:26 -04:00 committed by GitHub
parent 61be869352
commit 56eb650f09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 314 additions and 0 deletions

View file

@ -1,4 +1,5 @@
mod bash_tool;
mod batch_tool;
mod copy_path_tool;
mod create_directory_tool;
mod create_file_tool;
@ -26,6 +27,7 @@ use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use crate::bash_tool::BashTool;
use crate::batch_tool::BatchTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
@ -47,6 +49,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.register_tool(BashTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);

View file

@ -0,0 +1,301 @@
use anyhow::{anyhow, Result};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
/// The name of the tool to invoke
pub name: String,
/// The input to the tool in JSON format
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
/// The tool invocations to run as a batch. These tools will be run either sequentially
/// or concurrently depending on the `run_tools_concurrently` flag.
///
/// <example>
/// Basic file operations (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "read-file",
/// "input": {
/// "path": "src/main.rs"
/// }
/// },
/// {
/// "name": "list-directory",
/// "input": {
/// "path": "src/lib"
/// }
/// },
/// {
/// "name": "regex-search",
/// "input": {
/// "regex": "fn run\\("
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
///
/// <example>
/// Multiple find-replace operations on the same file (sequential)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "find-replace-file",
/// "input": {
/// "path": "src/config.rs",
/// "display_description": "Update default timeout value",
/// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
/// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
/// }
/// },
/// {
/// "name": "find-replace-file",
/// "input": {
/// "path": "src/config.rs",
/// "display_description": "Update API endpoint URL",
/// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
/// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
/// }
/// }
/// ],
/// "run_tools_concurrently": false
/// }
/// ```
/// </example>
///
/// <example>
/// Searching and analyzing code (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "regex-search",
/// "input": {
/// "regex": "impl Database"
/// }
/// },
/// {
/// "name": "path-search",
/// "input": {
/// "glob": "**/*test*.rs"
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
///
/// <example>
/// Multi-file refactoring (concurrent)
///
/// ```json
/// {
/// "invocations": [
/// {
/// "name": "find-replace-file",
/// "input": {
/// "path": "src/models/user.rs",
/// "display_description": "Add email field to User struct",
/// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
/// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
/// }
/// },
/// {
/// "name": "find-replace-file",
/// "input": {
/// "path": "src/db/queries.rs",
/// "display_description": "Update user insertion query",
/// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
/// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
/// }
/// }
/// ],
/// "run_tools_concurrently": true
/// }
/// ```
/// </example>
pub invocations: Vec<ToolInvocation>,
/// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
#[serde(default)]
pub run_tools_concurrently: bool,
}
pub struct BatchTool;
impl Tool for BatchTool {
fn name(&self) -> String {
"batch-tool".into()
}
fn needs_confirmation(&self) -> bool {
true
}
fn description(&self) -> String {
include_str!("./batch_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Cog
}
fn input_schema(&self) -> serde_json::Value {
let schema = schemars::schema_for!(BatchToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BatchToolInput>(input.clone()) {
Ok(input) => {
let count = input.invocations.len();
let mode = if input.run_tools_concurrently {
"concurrently"
} else {
"sequentially"
};
let first_tool_name = input
.invocations
.first()
.map(|inv| inv.name.clone())
.unwrap_or_default();
let all_same = input
.invocations
.iter()
.all(|invocation| invocation.name == first_tool_name);
if all_same {
format!(
"Run `{}` {} times {}",
first_tool_name,
input.invocations.len(),
mode
)
} else {
format!("Run {} tools {}", count, mode)
}
}
Err(_) => "Batch tools".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<BatchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
if input.invocations.is_empty() {
return Task::ready(Err(anyhow!("No tool invocations provided")));
}
let run_tools_concurrently = input.run_tools_concurrently;
let foreground_task = {
let working_set = ToolWorkingSet::default();
let invocations = input.invocations;
let messages = messages.to_vec();
cx.spawn(async move |cx| {
let mut tasks = Vec::new();
let mut tool_names = Vec::new();
for invocation in invocations {
let tool_name = invocation.name.clone();
tool_names.push(tool_name.clone());
let tool = cx
.update(|cx| working_set.tool(&tool_name, cx))
.map_err(|err| {
anyhow!("Failed to look up tool '{}': {}", tool_name, err)
})?;
let Some(tool) = tool else {
return Err(anyhow!("Tool '{}' not found", tool_name));
};
let project = project.clone();
let action_log = action_log.clone();
let messages = messages.clone();
let task = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(task);
}
Ok((tasks, tool_names))
})
};
cx.background_spawn(async move {
let (tasks, tool_names) = foreground_task.await?;
let mut results = Vec::with_capacity(tasks.len());
if run_tools_concurrently {
results.extend(join_all(tasks).await)
} else {
for task in tasks {
results.push(task.await);
}
};
let mut formatted_results = String::new();
let mut error_occurred = false;
for (i, result) in results.into_iter().enumerate() {
let tool_name = &tool_names[i];
match result {
Ok(output) => {
formatted_results
.push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
}
Err(err) => {
error_occurred = true;
formatted_results
.push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
}
}
}
if error_occurred {
formatted_results
.push_str("Note: Some tool invocations failed. See individual results above.");
}
Ok(formatted_results.trim().to_string())
})
}
}

View file

@ -0,0 +1,9 @@
Invoke multiple other tool calls either sequentially or concurrently.
This tool is useful when you need to perform several operations at once, improving efficiency by reducing the number of back-and-forth interactions needed to complete complex tasks.
If the tool calls are set to be run sequentially, then each tool call within the batch is executed in the order provided. If it's set to run concurrently, then they may run in a different order. Regardless, all tool calls will have the same permissions and context as if they were called individually.
This tool should never be used to run a total of one tool. Instead, just run that one tool directly. You can run batches within batches if desired, which is a way you can mix concurrent and sequential tool call execution.
When it's possible to run tools in a batch, you should run as many as possible in the batch, up to a maximum of 32. For example, don't run multiple consecutive batches of 10 when you could instead run one batch of 30.