agent: Enrich grep tool output with syntax information (#29601)
The `grep` tool used to include 4 lines of context around each match, but those lines were often unhelpful. This PR improves the behavior by using the range of the parent syntax node that contains the full line(s) matched. Match headers now also include symbol breadcrumbs, so the model can gather code structure before (or without) reading files.

````md
### impl GitRepository for RealGitRepository › fn compare_checkpoints › L1278-1284
```rust
let result = git
    .run(&[
        "diff-tree",
        "--quiet",
        &left.commit_sha.to_string(),
        &right.commit_sha.to_string(),
    ])
```
````

This positively impacts the `add_arg_to_trait_method` eval example, with better diff output, fewer tool failures, and reduced total turns.

Note: We have some plans to use an "elision" approach where we would combine all matches for a given file, skipping lines between them while keeping symbol declaration lines. The theory is that this would map more closely to the expected input for edits. For now, this PR is a significant improvement.

Release Notes:

- Agent: Enrich `grep` tool output with syntax information
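To make the capping behavior concrete before reading the diff, here is a minimal, self-contained sketch of the row arithmetic only: expand a match to the row range of its enclosing syntax node, cap how many of that node's lines are shown, and report how many lines were cut off. This is not the actual Zed implementation (which operates on `Point` ranges over buffer snapshots, as shown in the diff below); the `rows_to_print` helper and the plain tuple types are illustrative assumptions, while `MAX_ANCESTOR_LINES = 10` mirrors the constant the diff introduces.

```rust
// A minimal sketch (not the actual Zed code) of the capping logic described
// above, using plain row numbers instead of buffer Points.

const MAX_ANCESTOR_LINES: u32 = 10;

/// Hypothetical helper: given the matched rows and the rows of the syntax
/// ancestor that contains them, return the rows to print and how many
/// ancestor lines the cap cut off (if any).
fn rows_to_print(matched: (u32, u32), ancestor: (u32, u32)) -> ((u32, u32), Option<u32>) {
    let capped_end = ancestor.1.min(ancestor.0 + MAX_ANCESTOR_LINES);
    // Use the ancestor only if the capped range still covers the whole match;
    // otherwise fall back to the matched rows themselves.
    if ancestor.0 <= matched.0 && matched.1 <= capped_end {
        let remaining = ancestor.1 - capped_end;
        ((ancestor.0, capped_end), (remaining > 0).then_some(remaining))
    } else {
        (matched, None)
    }
}

fn main() {
    // A match on row 35 inside a syntax node spanning rows 30..=50:
    // print rows 30..=40 and report 10 lines remaining in the ancestor node.
    let (rows, remaining) = rows_to_print((35, 35), (30, 50));
    println!("print rows {}-{}, remaining: {:?}", rows.0, rows.1, remaining);
}
```

For a match on row 35 inside a node spanning rows 30-50, this prints rows 30-40 and reports 10 lines remaining, which corresponds to the "lines remaining in ancestor node" message the tool appends.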
parent 5507958327
commit fd17f2d8ae
3 changed files with 411 additions and 86 deletions
@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::OffsetRangeExt;
use language::{OffsetRangeExt, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
Project,
@@ -13,6 +13,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write, sync::Arc};
use ui::IconName;
use util::RangeExt;
use util::markdown::MarkdownInlineCode;
use util::paths::PathMatcher;
@@ -102,6 +103,7 @@ impl Tool for GrepTool {
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10;

let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
@@ -140,7 +142,7 @@ impl Tool for GrepTool {
let results = project.update(cx, |project, cx| project.search(query, cx));

cx.spawn(async move|cx| {
cx.spawn(async move |cx| {
futures::pin_mut!(results);

let mut output = String::new();
@@ -148,68 +150,113 @@ impl Tool for GrepTool {
let mut matches_found = 0;
let mut has_more_matches = false;

while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if ranges.is_empty() {
continue;
}

buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
let mut file_header_written = false;
let mut ranges = ranges
.into_iter()
.map(|range| {
let mut point_range = range.to_point(buffer);
point_range.start.row =
point_range.start.row.saturating_sub(CONTEXT_LINES);
point_range.start.column = 0;
point_range.end.row = cmp::min(
buffer.max_point().row,
point_range.end.row + CONTEXT_LINES,
);
point_range.end.column = buffer.line_len(point_range.end.row);
point_range
})
.peekable();
let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
(buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
})? else {
continue;
};

while let Some(mut range) = ranges.next() {
if skips_remaining > 0 {
skips_remaining -= 1;
continue;

while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}

let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

let mut ranges = ranges
.into_iter()
.map(|range| {
let matched = range.to_point(&snapshot);
let matched_end_line_len = snapshot.line_len(matched.end.row);
let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
let symbols = snapshot.symbols_containing(matched.start, None);

if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
let end_col = snapshot.line_len(end_row);
let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);

if capped_ancestor_range.contains_inclusive(&full_lines) {
return (capped_ancestor_range, Some(full_ancestor_range), symbols)
}
}

// We'd already found a full page of matches, and we just found one more.
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
return Ok(());
}
let mut matched = matched;
matched.start.column = 0;
matched.start.row =
matched.start.row.saturating_sub(CONTEXT_LINES);
matched.end.row = cmp::min(
snapshot.max_point().row,
matched.end.row + CONTEXT_LINES,
);
matched.end.column = snapshot.line_len(matched.end.row);

while let Some(next_range) = ranges.peek() {
if range.end.row >= next_range.start.row {
range.end = next_range.end;
ranges.next();
} else {
break;
}
}
(matched, None, symbols)
})
.peekable();

if !file_header_written {
writeln!(output, "\n## Matches in {}", path.display())?;
file_header_written = true;
}
let mut file_header_written = false;

let start_line = range.start.row + 1;
let end_line = range.end.row + 1;
writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
output.extend(buffer.text_for_range(range));
output.push_str("\n```\n");
while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next() {
if skips_remaining > 0 {
skips_remaining -= 1;
continue;
}

matches_found += 1;
// We'd already found a full page of matches, and we just found one more.
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
break 'outer;
}

while let Some((next_range, _, _)) = ranges.peek() {
if range.end.row >= next_range.start.row {
range.end = next_range.end;
ranges.next();
} else {
break;
}
}

Ok(())
})??;
if !file_header_written {
writeln!(output, "\n## Matches in {}", path.display())?;
file_header_written = true;
}

let end_row = range.end.row;
output.push_str("\n### ");

if let Some(parent_symbols) = &parent_symbols {
for symbol in parent_symbols {
write!(output, "{} › ", symbol.text)?;
}
}

if range.start.row == end_row {
writeln!(output, "L{}", range.start.row + 1)?;
} else {
writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
}

output.push_str("```\n");
output.extend(snapshot.text_for_range(range));
output.push_str("\n```\n");

if let Some(ancestor_range) = ancestor_range {
if end_row < ancestor_range.end.row {
let remaining_lines = ancestor_range.end.row - end_row;
writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
}
}

matches_found += 1;
}
}

if matches_found == 0 {
@@ -233,13 +280,16 @@ mod tests {
use super::*;
use assistant_tool::Tool;
use gpui::{AppContext, TestAppContext};
use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use settings::SettingsStore;
use unindent::Unindent;
use util::path;

#[gpui::test]
async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();

let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -327,6 +377,7 @@ mod tests {
#[gpui::test]
async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();

let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -401,6 +452,290 @@ mod tests {
);
}

/// Helper function to set up a syntax test environment
async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
use unindent::Unindent;
init_test(cx);
cx.executor().allow_parking();

let fs = FakeFs::new(cx.executor().clone());

// Create test file with syntax structures
fs.insert_tree(
"/root",
serde_json::json!({
"test_syntax.rs": r#"
fn top_level_function() {
println!("This is at the top level");
}

mod feature_module {
pub mod nested_module {
pub fn nested_function(
first_arg: String,
second_arg: i32,
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
}
}

struct MyStruct {
field1: String,
field2: i32,
}

impl MyStruct {
fn method_with_block() {
let condition = true;
if condition {
println!("Inside if block");
}
}

fn long_function() {
println!("Line 1");
println!("Line 2");
println!("Line 3");
println!("Line 4");
println!("Line 5");
println!("Line 6");
println!("Line 7");
println!("Line 8");
println!("Line 9");
println!("Line 10");
println!("Line 11");
println!("Line 12");
}
}

trait Processor {
fn process(&self, input: &str) -> String;
}

impl Processor for MyStruct {
fn process(&self, input: &str) -> String {
format!("Processed: {}", input)
}
}
"#.unindent().trim(),
}),
)
.await;

let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

project.update(cx, |project, _cx| {
project.languages().add(rust_lang().into())
});

project
}

#[gpui::test]
async fn test_grep_top_level_function(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;

// Test: Line at the top level of the file
let input = serde_json::to_value(GrepToolInput {
regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### fn top_level_function › L1-3
```
fn top_level_function() {
println!("This is at the top level");
}
```
"#
.unindent();
assert_eq!(result, expected);
}

#[gpui::test]
async fn test_grep_function_body(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;

// Test: Line inside a function body
let input = serde_json::to_value(GrepToolInput {
regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
```
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
```
"#
.unindent();
assert_eq!(result, expected);
}

#[gpui::test]
async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;

// Test: Line with a function argument
let input = serde_json::to_value(GrepToolInput {
regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
```
pub fn nested_function(
first_arg: String,
second_arg: i32,
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
```
"#
.unindent();
assert_eq!(result, expected);
}

#[gpui::test]
async fn test_grep_if_block(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;

// Test: Line inside an if block
let input = serde_json::to_value(GrepToolInput {
regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### impl MyStruct › fn method_with_block › L26-28
```
if condition {
println!("Inside if block");
}
```
"#
.unindent();
assert_eq!(result, expected);
}

#[gpui::test]
async fn test_grep_long_function_top(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;

// Test: Line in the middle of a long function - should show message about remaining lines
let input = serde_json::to_value(GrepToolInput {
regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### impl MyStruct › fn long_function › L31-41
```
fn long_function() {
println!("Line 1");
println!("Line 2");
println!("Line 3");
println!("Line 4");
println!("Line 5");
println!("Line 6");
println!("Line 7");
println!("Line 8");
println!("Line 9");
println!("Line 10");
```

3 lines remaining in ancestor node. Read the file to see all.
"#
.unindent();
assert_eq!(result, expected);
}

#[gpui::test]
async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;

// Test: Line in the long function
let input = serde_json::to_value(GrepToolInput {
regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();

let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:

## Matches in root/test_syntax.rs

### impl MyStruct › fn long_function › L41-45
```
println!("Line 10");
println!("Line 11");
println!("Line 12");
}
}
```
"#
.unindent();
assert_eq!(result, expected);
}

async fn run_grep_tool(
input: serde_json::Value,
project: Entity<Project>,
@@ -411,7 +746,13 @@ mod tests {
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

match task.output.await {
Ok(result) => result,
Ok(result) => {
if cfg!(windows) {
result.replace("root\\", "root/")
} else {
result
}
}
Err(e) => panic!("Failed to run grep tool: {}", e),
}
}
@@ -424,4 +765,20 @@ mod tests {
Project::init_settings(cx);
});
}

fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}
@@ -387,6 +387,7 @@ impl Response {
cx.assert_some(result, format!("called `{}`", tool_name))
}

#[allow(dead_code)]
pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
self.messages.iter().flat_map(|msg| &msg.tool_use)
}
@@ -1,7 +1,6 @@
use std::{collections::HashSet, path::Path};
use std::path::Path;

use anyhow::Result;
use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
use async_trait::async_trait;

use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
@@ -32,39 +31,7 @@ impl Example for AddArgToTraitMethod {
"#
));

let response = cx.run_to_end().await?;

// Reads files before it edits them

let mut read_files = HashSet::new();

for tool_use in response.tool_uses() {
match tool_use.name.as_str() {
"read_file" => {
if let Ok(input) = tool_use.parse_input::<ReadFileToolInput>() {
read_files.insert(input.path);
}
}
"create_file" => {
if let Ok(input) = tool_use.parse_input::<CreateFileToolInput>() {
read_files.insert(input.path);
}
}
"edit_file" => {
if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
cx.assert(
read_files.contains(input.path.to_str().unwrap()),
format!(
"Read before edit: {}",
&input.path.file_stem().unwrap().to_str().unwrap()
),
)
.ok();
}
}
_ => {}
}
}
let _ = cx.run_to_end().await?;

// Adds ignored argument to all but `batch_tool`