Fix bash tool output (#28391)

This commit is contained in:
Agus Zubiaga 2025-04-09 08:20:24 -06:00 committed by GitHub
parent 3a8fe4d973
commit 1cb4f8288d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool}; use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader; use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt}; use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task}; use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
@ -123,9 +123,15 @@ impl Tool for BashTool {
worktree.read(cx).abs_path() worktree.read(cx).abs_path()
}; };
cx.spawn(async move |_| { cx.background_spawn(run_command_limited(working_dir, input.command))
}
}
const LIMIT: usize = 16 * 1024;
async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
// Add 2>&1 to merge stderr into stdout for proper interleaving. // Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command); let command = format!("({}) 2>&1", command);
let mut cmd = new_smol_command("bash") let mut cmd = new_smol_command("bash")
.arg("-c") .arg("-c")
@ -139,28 +145,13 @@ impl Tool for BashTool {
let stdout = cmd.stdout.take().unwrap(); let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout); let mut reader = BufReader::new(stdout);
const MESSAGE_1: &str = "Command output too long. The first ";
const MESSAGE_2: &str = " bytes:\n\n";
const ERR_MESSAGE_1: &str = "Command failed with exit code ";
const ERR_MESSAGE_2: &str = "\n\n";
const STDOUT_LIMIT: usize = 8192;
const LIMIT: usize = STDOUT_LIMIT
- (MESSAGE_1.len()
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ MESSAGE_2.len()
+ ERR_MESSAGE_1.len()
+ 3 // status code
+ ERR_MESSAGE_2.len());
// Read one more byte to determine whether the output was truncated // Read one more byte to determine whether the output was truncated
let mut buffer = vec![0; LIMIT + 1]; let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0; let mut bytes_read = 0;
// Read until we reach the limit // Read until we reach the limit
loop { loop {
let read = reader.read(&mut buffer).await?; let read = reader.read(&mut buffer[bytes_read..]).await?;
if read == 0 { if read == 0 {
break; break;
} }
@ -182,28 +173,24 @@ impl Tool for BashTool {
reader.consume_unpin(skipped_bytes_len); reader.consume_unpin(skipped_bytes_len);
} }
let output_bytes = &buffer[..bytes_read]; let output_bytes = &buffer[..bytes_read.min(LIMIT)];
// Let the process continue running
let status = cmd.status().await.context("Failed to get command status")?; let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if bytes_read > LIMIT { let output_string = if bytes_read > LIMIT {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters. // multi-byte characters.
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n'); let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let output_string = String::from_utf8_lossy( let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())], let output_string = String::from_utf8_lossy(until_last_line);
);
format!( format!(
"{}{}{}{}", "Command output too long. The first {} bytes:\n\n{}",
MESSAGE_1,
output_string.len(), output_string.len(),
MESSAGE_2, output_block(&output_string),
output_string
) )
} else { } else {
String::from_utf8_lossy(&output_bytes).into() output_block(&String::from_utf8_lossy(&output_bytes))
}; };
let output_with_status = if status.success() { let output_with_status = if status.success() {
@ -214,17 +201,106 @@ impl Tool for BashTool {
} }
} else { } else {
format!( format!(
"{}{}{}{}", "Command failed with exit code {}\n\n{}",
ERR_MESSAGE_1,
status.code().unwrap_or(-1), status.code().unwrap_or(-1),
ERR_MESSAGE_2,
output_string, output_string,
) )
}; };
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
Ok(output_with_status) Ok(output_with_status)
}) }
fn output_block(output: &str) -> String {
format!(
"```\n{}{}```",
output,
if output.ends_with('\n') { "" } else { "\n" }
)
}
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test]
async fn test_run_command_simple(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result =
run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\nHello, World!\n```");
}
#[gpui::test]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let command =
"echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2";
let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
"```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
);
}
#[gpui::test]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
cx.executor().allow_parking();
// Command with multiple outputs that might require multiple reads
let result = run_command_limited(
Path::new(".").into(),
"echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
}
#[gpui::test]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2));
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
// Output should be exactly the limit
assert_eq!(content_length, LIMIT);
}
#[gpui::test]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
} }
} }