Fix bash tool output (#28391)

This commit is contained in:
Agus Zubiaga 2025-04-09 08:20:24 -06:00 committed by GitHub
parent 3a8fe4d973
commit 1cb4f8288d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool}; use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader; use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt}; use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task}; use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project; use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
@ -123,108 +123,184 @@ impl Tool for BashTool {
worktree.read(cx).abs_path() worktree.read(cx).abs_path()
}; };
cx.spawn(async move |_| { cx.background_spawn(run_command_limited(working_dir, input.command))
// Add 2>&1 to merge stderr into stdout for proper interleaving. }
let command = format!("({}) 2>&1", input.command); }
let mut cmd = new_smol_command("bash") const LIMIT: usize = 16 * 1024;
.arg("-c")
.arg(&command) async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
.current_dir(working_dir) // Add 2>&1 to merge stderr into stdout for proper interleaving.
.stdout(std::process::Stdio::piped()) let command = format!("({}) 2>&1", command);
.spawn()
.context("Failed to execute bash command")?; let mut cmd = new_smol_command("bash")
.arg("-c")
// Capture stdout with a limit .arg(&command)
let stdout = cmd.stdout.take().unwrap(); .current_dir(working_dir)
let mut reader = BufReader::new(stdout); .stdout(std::process::Stdio::piped())
.spawn()
const MESSAGE_1: &str = "Command output too long. The first "; .context("Failed to execute bash command")?;
const MESSAGE_2: &str = " bytes:\n\n";
const ERR_MESSAGE_1: &str = "Command failed with exit code "; // Capture stdout with a limit
const ERR_MESSAGE_2: &str = "\n\n"; let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout);
const STDOUT_LIMIT: usize = 8192;
// Read one more byte to determine whether the output was truncated
const LIMIT: usize = STDOUT_LIMIT let mut buffer = vec![0; LIMIT + 1];
- (MESSAGE_1.len() let mut bytes_read = 0;
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ MESSAGE_2.len() // Read until we reach the limit
+ ERR_MESSAGE_1.len() loop {
+ 3 // status code let read = reader.read(&mut buffer[bytes_read..]).await?;
+ ERR_MESSAGE_2.len()); if read == 0 {
break;
// Read one more byte to determine whether the output was truncated }
let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0; bytes_read += read;
if bytes_read > LIMIT {
// Read until we reach the limit bytes_read = LIMIT + 1;
loop { break;
let read = reader.read(&mut buffer).await?; }
if read == 0 { }
break;
} // Repeatedly fill the output reader's buffer without copying it.
loop {
bytes_read += read; let skipped_bytes = reader.fill_buf().await?;
if bytes_read > LIMIT { if skipped_bytes.is_empty() {
bytes_read = LIMIT + 1; break;
break; }
} let skipped_bytes_len = skipped_bytes.len();
} reader.consume_unpin(skipped_bytes_len);
}
// Repeatedly fill the output reader's buffer without copying it.
loop { let output_bytes = &buffer[..bytes_read.min(LIMIT)];
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() { let status = cmd.status().await.context("Failed to get command status")?;
break;
} let output_string = if bytes_read > LIMIT {
let skipped_bytes_len = skipped_bytes.len(); // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
reader.consume_unpin(skipped_bytes_len); // multi-byte characters.
} let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
let output_bytes = &buffer[..bytes_read]; let output_string = String::from_utf8_lossy(until_last_line);
// Let the process continue running format!(
let status = cmd.status().await.context("Failed to get command status")?; "Command output too long. The first {} bytes:\n\n{}",
output_string.len(),
let output_string = if bytes_read > LIMIT { output_block(&output_string),
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in )
// multi-byte characters. } else {
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n'); output_block(&String::from_utf8_lossy(&output_bytes))
let output_string = String::from_utf8_lossy( };
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
); let output_with_status = if status.success() {
if output_string.is_empty() {
format!( "Command executed successfully.".to_string()
"{}{}{}{}", } else {
MESSAGE_1, output_string.to_string()
output_string.len(), }
MESSAGE_2, } else {
output_string format!(
) "Command failed with exit code {}\n\n{}",
} else { status.code().unwrap_or(-1),
String::from_utf8_lossy(&output_bytes).into() output_string,
}; )
};
let output_with_status = if status.success() {
if output_string.is_empty() { Ok(output_with_status)
"Command executed successfully.".to_string() }
} else {
output_string.to_string() fn output_block(output: &str) -> String {
} format!(
} else { "```\n{}{}```",
format!( output,
"{}{}{}{}", if output.ends_with('\n') { "" } else { "\n" }
ERR_MESSAGE_1, )
status.code().unwrap_or(-1), }
ERR_MESSAGE_2,
output_string, #[cfg(test)]
) #[cfg(not(windows))]
}; mod tests {
use gpui::TestAppContext;
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
use super::*;
Ok(output_with_status)
}) #[gpui::test]
// Smoke test: the output of a single successful `echo` is returned wrapped in
// a fenced block, with nothing added before or after it.
async fn test_run_command_simple(cx: &mut TestAppContext) {
    cx.executor().allow_parking();

    let command = String::from("echo 'Hello, World!'");
    let result = run_command_limited(Path::new(".").into(), command).await;

    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output, "```\nHello, World!\n```");
}
// Alternates writes between stdout and stderr and expects all four lines back
// in the order they were issued, inside a single fenced block.
#[gpui::test]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
    cx.executor().allow_parking();

    // Each step only runs if the previous one succeeded (`&&`).
    let command = [
        "echo 'stdout 1'",
        "echo 'stderr 1' >&2",
        "echo 'stdout 2'",
        "echo 'stderr 2' >&2",
    ]
    .join(" && ");
    let result = run_command_limited(Path::new(".").into(), command).await;

    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output, "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```");
}
// Output produced in several bursts must still come back as one contiguous,
// correctly ordered block.
#[gpui::test]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
    cx.executor().allow_parking();

    // The short sleeps space the writes out in time, so the output might need
    // more than one read to collect.
    let command = "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'";
    let result = run_command_limited(Path::new(".").into(), command.to_string()).await;

    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output, "```\n1\n2\n3\n```");
}
// A single line longer than `LIMIT` has no newline to cut at, so the fenced
// content must be exactly `LIMIT` bytes and be preceded by the truncation
// notice.
#[gpui::test]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
    cx.executor().allow_parking();

    // One line, twice as long as the cap.
    let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2));
    // `expect` surfaces the underlying error if the command could not be run,
    // instead of a bare "assertion failed: result.is_ok()".
    let output = run_command_limited(Path::new(".").into(), cmd)
        .await
        .expect("running the command should succeed");

    // Same notice the multiline truncation test checks for; here the kept
    // prefix is the full `LIMIT` bytes.
    let expected_notice = format!("Command output too long. The first {} bytes:\n\n", LIMIT);
    assert!(
        output.starts_with(&expected_notice),
        "missing truncation notice; output starts with: {:?}",
        output.chars().take(64).collect::<String>()
    );

    // Fail loudly when a fence is missing rather than silently falling back to
    // measuring the whole message.
    let content_start = output
        .find("```\n")
        .map(|i| i + 4)
        .expect("output should contain an opening code fence");
    let content_end = output
        .rfind("\n```")
        .expect("output should contain a closing code fence");
    assert!(
        content_start <= content_end,
        "closing fence found before the end of the opening fence"
    );

    // Output should be exactly the limit
    assert_eq!(content_end - content_start, LIMIT);
}
#[gpui::test]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
} }
} }