Fix bash tool output (#28391)

This commit is contained in:
Agus Zubiaga 2025-04-09 08:20:24 -06:00 committed by GitHub
parent 3a8fe4d973
commit 1cb4f8288d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@ -123,108 +123,184 @@ impl Tool for BashTool {
worktree.read(cx).abs_path()
};
cx.spawn(async move |_| {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command);
let mut cmd = new_smol_command("bash")
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.spawn()
.context("Failed to execute bash command")?;
// Capture stdout with a limit
let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout);
const MESSAGE_1: &str = "Command output too long. The first ";
const MESSAGE_2: &str = " bytes:\n\n";
const ERR_MESSAGE_1: &str = "Command failed with exit code ";
const ERR_MESSAGE_2: &str = "\n\n";
const STDOUT_LIMIT: usize = 8192;
const LIMIT: usize = STDOUT_LIMIT
- (MESSAGE_1.len()
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ MESSAGE_2.len()
+ ERR_MESSAGE_1.len()
+ 3 // status code
+ ERR_MESSAGE_2.len());
// Read one more byte to determine whether the output was truncated
let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0;
// Read until we reach the limit
loop {
let read = reader.read(&mut buffer).await?;
if read == 0 {
break;
}
bytes_read += read;
if bytes_read > LIMIT {
bytes_read = LIMIT + 1;
break;
}
}
// Repeatedly fill the output reader's buffer without copying it.
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
}
let output_bytes = &buffer[..bytes_read];
// Let the process continue running
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if bytes_read > LIMIT {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let output_string = String::from_utf8_lossy(
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
);
format!(
"{}{}{}{}",
MESSAGE_1,
output_string.len(),
MESSAGE_2,
output_string
)
} else {
String::from_utf8_lossy(&output_bytes).into()
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"{}{}{}{}",
ERR_MESSAGE_1,
status.code().unwrap_or(-1),
ERR_MESSAGE_2,
output_string,
)
};
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
Ok(output_with_status)
})
cx.background_spawn(run_command_limited(working_dir, input.command))
}
}
const LIMIT: usize = 16 * 1024;
async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", command);
let mut cmd = new_smol_command("bash")
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.spawn()
.context("Failed to execute bash command")?;
// Capture stdout with a limit
let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout);
// Read one more byte to determine whether the output was truncated
let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0;
// Read until we reach the limit
loop {
let read = reader.read(&mut buffer[bytes_read..]).await?;
if read == 0 {
break;
}
bytes_read += read;
if bytes_read > LIMIT {
bytes_read = LIMIT + 1;
break;
}
}
// Repeatedly fill the output reader's buffer without copying it.
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
}
let output_bytes = &buffer[..bytes_read.min(LIMIT)];
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if bytes_read > LIMIT {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
let output_string = String::from_utf8_lossy(until_last_line);
format!(
"Command output too long. The first {} bytes:\n\n{}",
output_string.len(),
output_block(&output_string),
)
} else {
output_block(&String::from_utf8_lossy(&output_bytes))
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"Command failed with exit code {}\n\n{}",
status.code().unwrap_or(-1),
output_string,
)
};
Ok(output_with_status)
}
fn output_block(output: &str) -> String {
format!(
"```\n{}{}```",
output,
if output.ends_with('\n') { "" } else { "\n" }
)
}
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test]
async fn test_run_command_simple(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result =
run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\nHello, World!\n```");
}
#[gpui::test]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let command =
"echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2";
let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
"```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
);
}
#[gpui::test]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
cx.executor().allow_parking();
// Command with multiple outputs that might require multiple reads
let result = run_command_limited(
Path::new(".").into(),
"echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
}
#[gpui::test]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2));
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
// Output should be exactly the limit
assert_eq!(content_length, LIMIT);
}
#[gpui::test]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
}
}