From 1cb4f8288db7f1e3d37c396fa6112ed1beba2b07 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 9 Apr 2025 08:20:24 -0600 Subject: [PATCH] Fix bash tool output (#28391) --- crates/assistant_tools/src/bash_tool.rs | 284 +++++++++++++++--------- 1 file changed, 180 insertions(+), 104 deletions(-) diff --git a/crates/assistant_tools/src/bash_tool.rs b/crates/assistant_tools/src/bash_tool.rs index 23cbe3ca29..32f5a3d4e4 100644 --- a/crates/assistant_tools/src/bash_tool.rs +++ b/crates/assistant_tools/src/bash_tool.rs @@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ActionLog, Tool}; use futures::io::BufReader; use futures::{AsyncBufReadExt, AsyncReadExt}; -use gpui::{App, Entity, Task}; +use gpui::{App, AppContext, Entity, Task}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; @@ -123,108 +123,184 @@ impl Tool for BashTool { worktree.read(cx).abs_path() }; - cx.spawn(async move |_| { - // Add 2>&1 to merge stderr into stdout for proper interleaving. - let command = format!("({}) 2>&1", input.command); - - let mut cmd = new_smol_command("bash") - .arg("-c") - .arg(&command) - .current_dir(working_dir) - .stdout(std::process::Stdio::piped()) - .spawn() - .context("Failed to execute bash command")?; - - // Capture stdout with a limit - let stdout = cmd.stdout.take().unwrap(); - let mut reader = BufReader::new(stdout); - - const MESSAGE_1: &str = "Command output too long. The first "; - const MESSAGE_2: &str = " bytes:\n\n"; - const ERR_MESSAGE_1: &str = "Command failed with exit code "; - const ERR_MESSAGE_2: &str = "\n\n"; - - const STDOUT_LIMIT: usize = 8192; - - const LIMIT: usize = STDOUT_LIMIT - - (MESSAGE_1.len() - + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count - + MESSAGE_2.len() - + ERR_MESSAGE_1.len() - + 3 // status code - + ERR_MESSAGE_2.len()); - - // Read one more byte to determine whether the output was truncated - let mut buffer = vec![0; LIMIT + 1]; - let mut bytes_read = 0; - - // Read until we reach the limit - loop { - let read = reader.read(&mut buffer).await?; - if read == 0 { - break; - } - - bytes_read += read; - if bytes_read > LIMIT { - bytes_read = LIMIT + 1; - break; - } - } - - // Repeatedly fill the output reader's buffer without copying it. - loop { - let skipped_bytes = reader.fill_buf().await?; - if skipped_bytes.is_empty() { - break; - } - let skipped_bytes_len = skipped_bytes.len(); - reader.consume_unpin(skipped_bytes_len); - } - - let output_bytes = &buffer[..bytes_read]; - - // Let the process continue running - let status = cmd.status().await.context("Failed to get command status")?; - - let output_string = if bytes_read > LIMIT { - // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in - // multi-byte characters. - let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n'); - let output_string = String::from_utf8_lossy( - &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())], - ); - - format!( - "{}{}{}{}", - MESSAGE_1, - output_string.len(), - MESSAGE_2, - output_string - ) - } else { - String::from_utf8_lossy(&output_bytes).into() - }; - - let output_with_status = if status.success() { - if output_string.is_empty() { - "Command executed successfully.".to_string() - } else { - output_string.to_string() - } - } else { - format!( - "{}{}{}{}", - ERR_MESSAGE_1, - status.code().unwrap_or(-1), - ERR_MESSAGE_2, - output_string, - ) - }; - - debug_assert!(output_with_status.len() <= STDOUT_LIMIT); - - Ok(output_with_status) - }) + cx.background_spawn(run_command_limited(working_dir, input.command)) + } +} + +const LIMIT: usize = 16 * 1024; + +async fn run_command_limited(working_dir: Arc, command: String) -> Result { + // Add 2>&1 to merge stderr into stdout for proper interleaving. + let command = format!("({}) 2>&1", command); + + let mut cmd = new_smol_command("bash") + .arg("-c") + .arg(&command) + .current_dir(working_dir) + .stdout(std::process::Stdio::piped()) + .spawn() + .context("Failed to execute bash command")?; + + // Capture stdout with a limit + let stdout = cmd.stdout.take().unwrap(); + let mut reader = BufReader::new(stdout); + + // Read one more byte to determine whether the output was truncated + let mut buffer = vec![0; LIMIT + 1]; + let mut bytes_read = 0; + + // Read until we reach the limit + loop { + let read = reader.read(&mut buffer[bytes_read..]).await?; + if read == 0 { + break; + } + + bytes_read += read; + if bytes_read > LIMIT { + bytes_read = LIMIT + 1; + break; + } + } + + // Repeatedly fill the output reader's buffer without copying it. + loop { + let skipped_bytes = reader.fill_buf().await?; + if skipped_bytes.is_empty() { + break; + } + let skipped_bytes_len = skipped_bytes.len(); + reader.consume_unpin(skipped_bytes_len); + } + + let output_bytes = &buffer[..bytes_read.min(LIMIT)]; + + let status = cmd.status().await.context("Failed to get command status")?; + + let output_string = if bytes_read > LIMIT { + // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in + // multi-byte characters. + let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n'); + let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())]; + let output_string = String::from_utf8_lossy(until_last_line); + + format!( + "Command output too long. The first {} bytes:\n\n{}", + output_string.len(), + output_block(&output_string), + ) + } else { + output_block(&String::from_utf8_lossy(&output_bytes)) + }; + + let output_with_status = if status.success() { + if output_string.is_empty() { + "Command executed successfully.".to_string() + } else { + output_string.to_string() + } + } else { + format!( + "Command failed with exit code {}\n\n{}", + status.code().unwrap_or(-1), + output_string, + ) + }; + + Ok(output_with_status) +} + +fn output_block(output: &str) -> String { + format!( + "```\n{}{}```", + output, + if output.ends_with('\n') { "" } else { "\n" } + ) +} + +#[cfg(test)] +#[cfg(not(windows))] +mod tests { + use gpui::TestAppContext; + + use super::*; + + #[gpui::test] + async fn test_run_command_simple(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let result = + run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "```\nHello, World!\n```"); + } + + #[gpui::test] + async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let command = + "echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2"; + let result = run_command_limited(Path::new(".").into(), command.to_string()).await; + + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```" + ); + } + + #[gpui::test] + async fn test_multiple_output_reads(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + // Command with multiple outputs that might require multiple reads + let result = run_command_limited( + Path::new(".").into(), + "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "```\n1\n2\n3\n```"); + } + + #[gpui::test] + async fn test_output_truncation_single_line(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2)); + + let result = run_command_limited(Path::new(".").into(), cmd).await; + + assert!(result.is_ok()); + let output = result.unwrap(); + + let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0); + let content_end = output.rfind("\n```").unwrap_or(output.len()); + let content_length = content_end - content_start; + + // Output should be exactly the limit + assert_eq!(content_length, LIMIT); + } + + #[gpui::test] + async fn test_output_truncation_multiline(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160); + let result = run_command_limited(Path::new(".").into(), cmd).await; + + assert!(result.is_ok()); + let output = result.unwrap(); + + assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n")); + + let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0); + let content_end = output.rfind("\n```").unwrap_or(output.len()); + let content_length = content_end - content_start; + + assert!(content_length <= LIMIT); } }