agent: Truncate bash tool output (#28291)
The bash tool will now truncate its output to 8192 bytes (or the last newline before that). We also added a global limit for any tool that produces a clearly large output that wouldn't fit the context window. Release Notes: - agent: Truncate bash tool output --------- Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
1774cad933
commit
85c5d8af3a
4 changed files with 164 additions and 16 deletions
|
@ -1487,6 +1487,7 @@ impl Thread {
|
|||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
output,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
|
@ -1831,7 +1832,7 @@ impl Thread {
|
|||
));
|
||||
|
||||
self.tool_use
|
||||
.insert_tool_output(tool_use_id.clone(), tool_name, err);
|
||||
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
|
|
|
@ -7,10 +7,11 @@ use futures::FutureExt as _;
|
|||
use futures::future::Shared;
|
||||
use gpui::{App, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, MessageContent, Role,
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
};
|
||||
use ui::IconName;
|
||||
use util::truncate_lines_to_byte_limit;
|
||||
|
||||
use crate::thread::MessageId;
|
||||
use crate::thread_store::SerializedMessage;
|
||||
|
@ -331,9 +332,32 @@ impl ToolUseState {
|
|||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
output: Result<String>,
|
||||
cx: &App,
|
||||
) -> Option<PendingToolUse> {
|
||||
match output {
|
||||
Ok(tool_result) => {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
|
||||
|
||||
// Protect from clearly large output
|
||||
let tool_output_limit = model_registry
|
||||
.default_model()
|
||||
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
let tool_result = if tool_result.len() <= tool_output_limit {
|
||||
tool_result
|
||||
} else {
|
||||
let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
|
||||
|
||||
format!(
|
||||
"Tool result too long. The first {} bytes:\n\n{}",
|
||||
truncated.len(),
|
||||
truncated
|
||||
)
|
||||
};
|
||||
|
||||
self.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool};
|
||||
use futures::io::BufReader;
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
|
@ -125,29 +127,90 @@ impl Tool for BashTool {
|
|||
// Add 2>&1 to merge stderr into stdout for proper interleaving.
|
||||
let command = format!("({}) 2>&1", input.command);
|
||||
|
||||
let output = new_smol_command("bash")
|
||||
let mut cmd = new_smol_command("bash")
|
||||
.arg("-c")
|
||||
.arg(&command)
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.await
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.context("Failed to execute bash command")?;
|
||||
|
||||
let output_string = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
// Capture stdout with a limit
|
||||
let stdout = cmd.stdout.take().unwrap();
|
||||
let mut reader = BufReader::new(stdout);
|
||||
|
||||
if output.status.success() {
|
||||
const MESSAGE_1: &str = "Command output too long. The first ";
|
||||
const MESSAGE_2: &str = " bytes:\n\n";
|
||||
const ERR_MESSAGE_1: &str = "Command failed with exit code ";
|
||||
const ERR_MESSAGE_2: &str = "\n\n";
|
||||
|
||||
const STDOUT_LIMIT: usize = 8192;
|
||||
|
||||
const LIMIT: usize = STDOUT_LIMIT
|
||||
- (MESSAGE_1.len()
|
||||
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
|
||||
+ MESSAGE_2.len()
|
||||
+ ERR_MESSAGE_1.len()
|
||||
+ 3 // status code
|
||||
+ ERR_MESSAGE_2.len());
|
||||
|
||||
// Read one more byte to determine whether the output was truncated
|
||||
let mut buffer = vec![0; LIMIT + 1];
|
||||
let bytes_read = reader.read(&mut buffer).await?;
|
||||
|
||||
// Repeatedly fill the output reader's buffer without copying it.
|
||||
loop {
|
||||
let skipped_bytes = reader.fill_buf().await?;
|
||||
if skipped_bytes.is_empty() {
|
||||
break;
|
||||
}
|
||||
let skipped_bytes_len = skipped_bytes.len();
|
||||
reader.consume_unpin(skipped_bytes_len);
|
||||
}
|
||||
|
||||
let output_bytes = &buffer[..bytes_read];
|
||||
|
||||
// Let the process continue running
|
||||
let status = cmd.status().await.context("Failed to get command status")?;
|
||||
|
||||
let output_string = if bytes_read > LIMIT {
|
||||
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
|
||||
// multi-byte characters.
|
||||
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
|
||||
let output_string = String::from_utf8_lossy(
|
||||
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
|
||||
);
|
||||
|
||||
format!(
|
||||
"{}{}{}{}",
|
||||
MESSAGE_1,
|
||||
output_string.len(),
|
||||
MESSAGE_2,
|
||||
output_string
|
||||
)
|
||||
} else {
|
||||
String::from_utf8_lossy(&output_bytes).into()
|
||||
};
|
||||
|
||||
let output_with_status = if status.success() {
|
||||
if output_string.is_empty() {
|
||||
Ok("Command executed successfully.".to_string())
|
||||
"Command executed successfully.".to_string()
|
||||
} else {
|
||||
Ok(output_string)
|
||||
output_string.to_string()
|
||||
}
|
||||
} else {
|
||||
Ok(format!(
|
||||
"Command failed with exit code {}\n{}",
|
||||
output.status.code().unwrap_or(-1),
|
||||
&output_string
|
||||
))
|
||||
}
|
||||
format!(
|
||||
"{}{}{}{}",
|
||||
ERR_MESSAGE_1,
|
||||
status.code().unwrap_or(-1),
|
||||
ERR_MESSAGE_2,
|
||||
output_string,
|
||||
)
|
||||
};
|
||||
|
||||
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
|
||||
|
||||
Ok(output_with_status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,6 +145,66 @@ pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
/// Truncates the string at a character boundary, such that the result is less than `max_bytes` in
|
||||
/// length.
|
||||
pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str {
|
||||
if s.len() < max_bytes {
|
||||
return s;
|
||||
}
|
||||
|
||||
for i in (0..max_bytes).rev() {
|
||||
if s.is_char_boundary(i) {
|
||||
return &s[..i];
|
||||
}
|
||||
}
|
||||
|
||||
""
|
||||
}
|
||||
|
||||
/// Takes a prefix of complete lines which fit within the byte limit. If the first line is longer
|
||||
/// than the limit, truncates at a character boundary.
|
||||
pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str {
|
||||
if s.len() < max_bytes {
|
||||
return s;
|
||||
}
|
||||
|
||||
for i in (0..max_bytes).rev() {
|
||||
if s.is_char_boundary(i) {
|
||||
if s.as_bytes()[i] == b'\n' {
|
||||
// Since the i-th character is \n, valid to slice at i + 1.
|
||||
return &s[..i + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
truncate_to_byte_limit(s, max_bytes)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_lines_to_byte_limit() {
|
||||
let text = "Line 1\nLine 2\nLine 3\nLine 4";
|
||||
|
||||
// Limit that includes all lines
|
||||
assert_eq!(truncate_lines_to_byte_limit(text, 100), text);
|
||||
|
||||
// Exactly the first line
|
||||
assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n");
|
||||
|
||||
// Limit between lines
|
||||
assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n");
|
||||
assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n");
|
||||
|
||||
// Limit before first newline
|
||||
assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line ");
|
||||
|
||||
// Test with non-ASCII characters
|
||||
let text_utf8 = "Line 1\nLíne 2\nLine 3";
|
||||
assert_eq!(
|
||||
truncate_lines_to_byte_limit(text_utf8, 15),
|
||||
"Line 1\nLíne 2\n"
|
||||
);
|
||||
}
|
||||
|
||||
pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
|
||||
let prev = *value;
|
||||
*value += T::from(1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue