assistant edit tool: Fuzzy match search block (#26935)

Release Notes:

- N/A

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Agus Zubiaga 2025-03-17 15:33:20 -03:00 committed by GitHub
parent 798af67dc1
commit 94b63808e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 258 additions and 79 deletions

2
Cargo.lock generated
View file

@ -716,6 +716,7 @@ dependencies = [
"gpui", "gpui",
"language", "language",
"language_model", "language_model",
"pretty_assertions",
"project", "project",
"rand 0.8.5", "rand 0.8.5",
"release_channel", "release_channel",
@ -725,6 +726,7 @@ dependencies = [
"settings", "settings",
"theme", "theme",
"ui", "ui",
"unindent",
"util", "util",
"workspace", "workspace",
"worktree", "worktree",

View file

@ -48,7 +48,12 @@ fn main() {
let crate_dir = PathBuf::from("../zed-agent-bench"); let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap(); let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
let repos_dir = crate_dir.join("repos").canonicalize().unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
let all_evals = std::fs::read_dir(&evaluation_data_dir) let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap() .unwrap()

View file

@ -38,5 +38,7 @@ rand.workspace = true
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] }
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] }

View file

@ -1,5 +1,6 @@
mod edit_action; mod edit_action;
pub mod log; pub mod log;
mod resolve_search_block;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool}; use assistant_tool::{ActionLog, Tool};
@ -7,16 +8,17 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser}; use edit_action::{EditAction, EditActionParser};
use futures::StreamExt; use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task}; use gpui::{App, AsyncApp, Entity, Task};
use language::OffsetRangeExt;
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
}; };
use log::{EditToolLog, EditToolRequestId}; use log::{EditToolLog, EditToolRequestId};
use project::{search::SearchQuery, Project}; use project::Project;
use resolve_search_block::resolve_search_block;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Write; use std::fmt::Write;
use std::sync::Arc; use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt; use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -129,24 +131,11 @@ struct EditToolRequest {
parser: EditActionParser, parser: EditActionParser,
output: String, output: String,
changed_buffers: HashSet<Entity<language::Buffer>>, changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>, project: Entity<Project>,
action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>, tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
} }
#[derive(Debug)]
enum DiffResult {
BadSearch(BadSearch),
Diff(language::Diff),
}
#[derive(Debug)]
struct BadSearch {
file_path: String,
search: String,
}
impl EditToolRequest { impl EditToolRequest {
fn new( fn new(
input: EditFilesToolInput, input: EditFilesToolInput,
@ -204,7 +193,6 @@ impl EditToolRequest {
// we start with the success header so we don't need to shift the output in the common case // we start with the success header so we don't need to shift the output in the common case
output: Self::SUCCESS_OUTPUT_HEADER.to_string(), output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
changed_buffers: HashSet::default(), changed_buffers: HashSet::default(),
bad_searches: Vec::new(),
action_log, action_log,
project, project,
tool_log, tool_log,
@ -251,36 +239,30 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))? .update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?; .await?;
let result = match action { let diff = match action {
EditAction::Replace { EditAction::Replace {
old, old,
new, new,
file_path, file_path: _,
} => { } => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
cx.background_executor() let diff = cx
.spawn(Self::replace_diff(old, new, file_path, snapshot)) .background_executor()
.await .spawn(Self::replace_diff(old, new, snapshot))
.await;
anyhow::Ok(diff)
} }
EditAction::Write { content, .. } => Ok(DiffResult::Diff( EditAction::Write { content, .. } => Ok(buffer
buffer
.read_with(cx, |buffer, cx| buffer.diff(content, cx))? .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await, .await),
)),
}?; }?;
match result {
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
write!(&mut self.output, "\n\n{}", source)?; write!(&mut self.output, "\n\n{}", source)?;
self.changed_buffers.insert(buffer); self.changed_buffers.insert(buffer);
}
}
Ok(()) Ok(())
} }
@ -288,29 +270,9 @@ impl EditToolRequest {
async fn replace_diff( async fn replace_diff(
old: String, old: String,
new: String, new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot, snapshot: language::BufferSnapshot,
) -> Result<DiffResult> { ) -> language::Diff {
let query = SearchQuery::text( let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new); let diff = language::text_diff(&old, &new);
let edits = diff let edits = diff
@ -328,7 +290,7 @@ impl EditToolRequest {
edits, edits,
}; };
anyhow::Ok(DiffResult::Diff(diff)) diff
} }
const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:"; const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@ -354,7 +316,7 @@ impl EditToolRequest {
let errors = self.parser.errors(); let errors = self.parser.errors();
if errors.is_empty() && self.bad_searches.is_empty() { if errors.is_empty() {
if changed_buffer_count == 0 { if changed_buffer_count == 0 {
return Err(anyhow!( return Err(anyhow!(
"The instructions didn't lead to any changes. You might need to consult the file contents first." "The instructions didn't lead to any changes. You might need to consult the file contents first."
@ -377,24 +339,6 @@ impl EditToolRequest {
); );
} }
if !self.bad_searches.is_empty() {
writeln!(
&mut output,
"\n\nThese searches failed because they didn't match any strings:"
)?;
for replace in self.bad_searches {
writeln!(
&mut output,
"- '{}' does not appear in `{}`",
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
replace.file_path
)?;
}
write!(&mut output, "Make sure to use exact searches.")?;
}
if !errors.is_empty() { if !errors.is_empty() {
writeln!( writeln!(
&mut output, &mut output,

View file

@ -0,0 +1,226 @@
use language::{Anchor, Bias, BufferSnapshot};
use std::ops::Range;
/// The neighbouring matrix cell a `SearchState` was derived from.
///
/// `resolve_search_block` stores one of these in every cell it fills and then
/// follows them backwards from the end of the best match to find its start.
// The variant order is load-bearing: `SearchState` derives `Ord` with
// `direction` as its second field, so among candidates of equal cost `Up`
// sorts before `Left`, and `Left` sorts before `Diagonal`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
/// Came from the row above: a query byte was consumed without a buffer byte.
Up,
/// Came from the column to the left: a buffer byte was consumed without a query byte.
Left,
/// Came from the upper-left cell: one query byte was paired with one buffer byte.
Diagonal,
}
/// A single cell of the alignment matrix: the accumulated cost of reaching the
/// cell together with the step that produced that cost.
///
/// The derived ordering compares `cost` first and only consults `direction`
/// when the costs are equal, so taking the minimum of several candidates picks
/// the cheapest one and breaks ties by `SearchDirection`'s variant order.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    // Must stay the first field: derived `Ord` is lexicographic over fields.
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    /// Pairs an accumulated `cost` with the `direction` it was reached from.
    fn new(cost: u32, direction: SearchDirection) -> Self {
        SearchState { cost, direction }
    }
}
/// A dense grid of `SearchState` cells kept in one flat, row-major vector.
struct SearchMatrix {
    /// Width of a row; used to turn a `(row, col)` pair into a flat index.
    cols: usize,
    /// The cells, one row after another.
    data: Vec<SearchState>,
}

impl SearchMatrix {
    /// Creates a `rows` by `cols` grid in which every cell starts with cost 0
    /// and direction `Diagonal`.
    fn new(rows: usize, cols: usize) -> Self {
        let blank = SearchState::new(0, SearchDirection::Diagonal);
        let cell_count = rows * cols;
        Self {
            cols,
            data: vec![blank; cell_count],
        }
    }

    /// Position of `(row, col)` inside `data`.
    fn flat_index(&self, row: usize, col: usize) -> usize {
        row * self.cols + col
    }

    /// Returns a copy of the cell at `(row, col)`.
    ///
    /// Panics if the computed flat index lies outside the grid.
    fn get(&self, row: usize, col: usize) -> SearchState {
        let index = self.flat_index(row, col);
        self.data[index]
    }

    /// Overwrites the cell at `(row, col)` with `cost`.
    ///
    /// Panics if the computed flat index lies outside the grid.
    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
        let index = self.flat_index(row, col);
        self.data[index] = cost;
    }
}
/// Finds the region of `buffer` that most closely resembles `search_query`
/// and returns it as an anchor range widened to line boundaries.
///
/// Every byte of the query is aligned against every byte of the buffer using
/// a weighted edit distance. The match may start at any buffer offset at no
/// cost. Skipping a buffer byte costs 3 and dropping a query byte costs 10,
/// except for ASCII whitespace, where either costs 1. Pairing two different
/// bytes costs the sum of the two.
///
/// The returned range begins at column 0 of the line on which the match
/// starts. Unless the match ends exactly at column 0, it extends to the end of
/// the line on which the match ends.
///
/// The alignment matrix holds `(search_query.len() + 1) * (buffer.len() + 1)`
/// cells.
pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
    const INSERTION_COST: u32 = 3;
    const DELETION_COST: u32 = 10;
    const WHITESPACE_INSERTION_COST: u32 = 1;
    const WHITESPACE_DELETION_COST: u32 = 1;

    let buffer_len = buffer.len();
    let query_len = search_query.len();
    let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);

    // Row 0 keeps the zero-cost cells from `SearchMatrix::new`, which is what
    // lets a match begin anywhere in the buffer. Column 0 accumulates the cost
    // of dropping every query byte seen so far.
    let mut dropped_query_cost = 0_u32;
    for (prev_row, query_byte) in search_query.bytes().enumerate() {
        let cur_row = prev_row + 1;

        let mut deletion_cost = DELETION_COST;
        if query_byte.is_ascii_whitespace() {
            deletion_cost = WHITESPACE_DELETION_COST;
        }

        dropped_query_cost = dropped_query_cost.saturating_add(deletion_cost);
        let first_cell = SearchState::new(dropped_query_cost, SearchDirection::Diagonal);
        matrix.set(cur_row, 0, first_cell);

        let buffer_bytes = buffer.bytes_in_range(0..buffer.len()).flatten();
        for (prev_col, buffer_byte) in buffer_bytes.enumerate() {
            let cur_col = prev_col + 1;

            let mut insertion_cost = INSERTION_COST;
            if buffer_byte.is_ascii_whitespace() {
                insertion_cost = WHITESPACE_INSERTION_COST;
            }

            let above = matrix.get(prev_row, cur_col).cost;
            let beside = matrix.get(cur_row, prev_col).cost;
            let corner = matrix.get(prev_row, prev_col).cost;

            // Drop the query byte.
            let via_up = SearchState::new(
                above.saturating_add(deletion_cost),
                SearchDirection::Up,
            );
            // Skip the buffer byte.
            let via_left = SearchState::new(
                beside.saturating_add(insertion_cost),
                SearchDirection::Left,
            );
            // Pair the two bytes; free when they are equal.
            let paired_cost = if query_byte != *buffer_byte {
                corner.saturating_add(deletion_cost + insertion_cost)
            } else {
                corner
            };
            let via_diagonal = SearchState::new(paired_cost, SearchDirection::Diagonal);

            // `SearchState`'s ordering picks the lowest cost and breaks ties
            // by direction.
            let cheapest = via_up.min(via_left).min(via_diagonal);
            matrix.set(cur_row, cur_col, cheapest);
        }
    }

    // The match ends at the first column of the last row with the lowest
    // cost; if no column beats `u32::MAX`, it ends at the end of the buffer.
    let (match_end, _) = (1..=buffer_len).fold(
        (buffer_len, u32::MAX),
        |(end_so_far, cost_so_far), col| {
            let cost = matrix.get(query_len, col).cost;
            if cost < cost_so_far {
                (col, cost)
            } else {
                (end_so_far, cost_so_far)
            }
        },
    );

    // Walk the recorded directions backwards until the query is exhausted or
    // the start of the buffer is reached.
    let mut query_ix = query_len;
    let mut match_start = match_end;
    while query_ix > 0 && match_start > 0 {
        let (query_step, buffer_step) = match matrix.get(query_ix, match_start).direction {
            SearchDirection::Diagonal => (1, 1),
            SearchDirection::Up => (1, 0),
            SearchDirection::Left => (0, 1),
        };
        query_ix -= query_step;
        match_start -= buffer_step;
    }

    // The alignment works on raw bytes, so both offsets go through
    // `clip_offset` before being turned into points.
    let start_offset = buffer.clip_offset(match_start, Bias::Left);
    let mut start = buffer.offset_to_point(start_offset);
    start.column = 0;

    let end_offset = buffer.clip_offset(match_end, Bias::Right);
    let mut end = buffer.offset_to_point(end_offset);
    if end.column > 0 {
        end.column = buffer.line_len(end.row);
    }

    buffer.anchor_after(start)..buffer.anchor_before(end)
}
// Tests for `resolve_search_block`. Each fixture is buffer text in which the
// `«` and `»` markers surround the range the resolver is expected to return.
#[cfg(test)]
mod tests {
use crate::edit_files_tool::resolve_search_block::resolve_search_block;
use gpui::{prelude::*, App};
use language::{Buffer, OffsetRangeExt as _};
use unindent::Unindent as _;
use util::test::{generate_marked_text, marked_text_ranges};
/// Checks that approximate queries resolve to the expected whole-line ranges.
#[gpui::test]
fn test_resolve_search_block(cx: &mut App) {
// The query lacks the buffer's leading whitespace and stops before
// " sit amet"; both marked lines are still expected in full.
assert_resolved(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor",
cx,
);
// The query names a different parameter and omits the return type, yet it
// is expected to resolve to `foo1` rather than `foo2`.
assert_resolved(
&"
«fn foo1(a: usize) -> usize {
40
}»
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"fn foo1(b: usize) {\n40\n}",
cx,
);
// A single-line query is expected to match a call chain that the buffer
// spreads over several lines.
assert_resolved(
&"
fn main() {
« Foo
.bar()
.baz()
.qux()»
}
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"Foo.bar.baz.qux()",
cx,
);
// The query leaves out the `three()` line that the buffer contains; the
// expected range still spans `two()` through `six()`.
assert_resolved(
&"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
» seven() { return 7; }
eight() { return 8; }
}
"
.unindent(),
&"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"
.unindent(),
cx,
);
}
/// Resolves `query` against a local buffer built from the fixture and asserts
/// that re-marking the text with the returned range reproduces
/// `text_with_expected_range` exactly.
#[track_caller]
fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
// Presumably `text` is the fixture with its markers removed; the ranges
// returned alongside it are not used here.
let (text, _) = marked_text_ranges(text_with_expected_range, false);
let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
let snapshot = buffer.read(cx).snapshot();
// Convert the anchor range to offsets so it can be rendered back as markers.
let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
let text_with_actual_range = generate_marked_text(&text, &[range], false);
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
}
}