assistant edit tool: Use buffer search and replace in background (#26679)

Instead of getting the whole text from the buffer, replacing with
`String::replace`, and getting a whole diff, we'll now use `SearchQuery`
to get a range, diff only that range, and apply it (all in the
background).

When a search matches nothing, we record a "bad search", keep going, and
report all of them to the model at the end.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Agus Zubiaga 2025-03-13 12:25:49 -03:00 committed by GitHub
parent 6767e98e00
commit 8ec0309645
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 219 additions and 111 deletions

View file

@ -6,16 +6,17 @@ use assistant_tool::Tool;
use collections::HashSet; use collections::HashSet;
use edit_action::{EditAction, EditActionParser}; use edit_action::{EditAction, EditActionParser};
use futures::StreamExt; use futures::StreamExt;
use gpui::{App, Entity, Task}; use gpui::{App, AsyncApp, Entity, Task};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
}; };
use log::{EditToolLog, EditToolRequestId}; use log::{EditToolLog, EditToolRequestId};
use project::Project; use project::{search::SearchQuery, Project};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Write; use std::fmt::Write;
use std::sync::Arc; use std::sync::Arc;
use util::paths::PathMatcher;
use util::ResultExt; use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -93,7 +94,7 @@ impl Tool for EditFilesTool {
}); });
let task = let task =
EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx); EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let result = task.await; let result = task.await;
@ -112,13 +113,33 @@ impl Tool for EditFilesTool {
}) })
} }
None => EditFilesTool::run(input, messages, project, None, cx), None => EditToolRequest::new(input, messages, project, None, cx),
} }
} }
} }
impl EditFilesTool { struct EditToolRequest {
fn run( parser: EditActionParser,
changed_buffers: HashSet<Entity<language::Buffer>>,
bad_searches: Vec<BadSearch>,
project: Entity<Project>,
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
/// Outcome of computing the edit for a single action.
#[derive(Debug)]
enum DiffResult {
/// The search text had no match; carries the details to report back.
BadSearch(BadSearch),
/// A diff that is ready to be applied to the buffer.
Diff(language::Diff),
}
/// A search/replace action whose search produced no match.
#[derive(Debug)]
struct BadSearch {
/// Display form of the path of the file that was searched.
file_path: String,
/// Text reported to the model as not appearing in `file_path`.
search: String,
}
impl EditToolRequest {
fn new(
input: EditFilesToolInput, input: EditFilesToolInput,
messages: &[LanguageModelRequestMessage], messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
@ -147,121 +168,208 @@ impl EditFilesTool {
}); });
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let request = LanguageModelRequest { let llm_request = LanguageModelRequest {
messages, messages,
tools: vec![], tools: vec![],
stop: vec![], stop: vec![],
temperature: Some(0.0), temperature: Some(0.0),
}; };
let mut parser = EditActionParser::new(); let stream = model.stream_completion_text(llm_request, &cx);
let stream = model.stream_completion_text(request, &cx);
let mut chunks = stream.await?; let mut chunks = stream.await?;
let mut changed_buffers = HashSet::default(); let mut request = Self {
let mut applied_edits = 0; parser: EditActionParser::new(),
changed_buffers: HashSet::default(),
let log = log.clone(); bad_searches: Vec::new(),
project,
log,
};
while let Some(chunk) = chunks.stream.next().await { while let Some(chunk) = chunks.stream.next().await {
let chunk = chunk?; request.process_response_chunk(&chunk?, &mut cx).await?;
}
let new_actions = parser.parse_chunk(&chunk); request.finalize(&mut cx).await
})
}
if let Some((ref log, req_id)) = log { async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
log.update(&mut cx, |log, cx| { let new_actions = self.parser.parse_chunk(chunk);
log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
if let Some((ref log, req_id)) = self.log {
log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
}) })
.log_err(); .log_err();
} }
for action in new_actions { for action in new_actions {
let project_path = project.read_with(&cx, |project, cx| { self.apply_action(action, cx).await?;
}
Ok(())
}
async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
let project_path = self.project.read_with(cx, |project, cx| {
project project
.find_project_path(action.file_path(), cx) .find_project_path(action.file_path(), cx)
.context("Path not found in project") .context("Path not found in project")
})??; })??;
let buffer = project let buffer = self
.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))? .project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?; .await?;
let diff = buffer let result = match action {
.read_with(&cx, |buffer, cx| {
let new_text = match action {
EditAction::Replace { EditAction::Replace {
file_path,
old, old,
new, new,
file_path,
} => { } => {
// TODO: Replace in background? let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let text = buffer.text();
if text.contains(&old) { cx.background_executor()
text.replace(&old, &new) .spawn(Self::replace_diff(old, new, file_path, snapshot))
} else { .await
return Err(anyhow!( }
"Could not find search text in {}", EditAction::Write { content, .. } => Ok(DiffResult::Diff(
file_path.display() buffer
)); .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
.await,
)),
}?;
match result {
DiffResult::BadSearch(invalid_replace) => {
self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
self.changed_buffers.insert(buffer);
} }
} }
EditAction::Write { content, .. } => content,
Ok(())
}
async fn replace_diff(
old: String,
new: String,
file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot,
) -> Result<DiffResult> {
let query = SearchQuery::text(
old.clone(),
false,
true,
true,
PathMatcher::new(&[])?,
PathMatcher::new(&[])?,
None,
)?;
let matches = query.search(&snapshot, None).await;
if matches.is_empty() {
return Ok(DiffResult::BadSearch(BadSearch {
search: new.clone(),
file_path: file_path.display().to_string(),
}));
}
let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
.into_iter()
.map(|(old_range, text)| {
let start = edit_range.start + old_range.start;
let end = edit_range.start + old_range.end;
(start..end, text)
})
.collect::<Vec<_>>();
let diff = language::Diff {
base_version: snapshot.version().clone(),
line_ending: snapshot.line_ending(),
edits,
}; };
anyhow::Ok(buffer.diff(new_text, cx)) anyhow::Ok(DiffResult::Diff(diff))
})??
.await;
let _clock =
buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
changed_buffers.insert(buffer);
applied_edits += 1;
}
} }
let mut answer = match changed_buffers.len() { async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
let mut answer = match self.changed_buffers.len() {
0 => "No files were edited.".to_string(), 0 => "No files were edited.".to_string(),
1 => "Successfully edited ".to_string(), 1 => "Successfully edited ".to_string(),
_ => "Successfully edited these files:\n\n".to_string(), _ => "Successfully edited these files:\n\n".to_string(),
}; };
// Save each buffer once at the end // Save each buffer once at the end
for buffer in changed_buffers { for buffer in self.changed_buffers {
project let (path, save_task) = self.project.update(cx, |project, cx| {
.update(&mut cx, |project, cx| { let path = buffer
if let Some(file) = buffer.read(&cx).file() { .read(cx)
let _ = writeln!(&mut answer, "{}", &file.full_path(cx).display()); .file()
.map(|file| file.path().display().to_string());
let task = project.save_buffer(buffer.clone(), cx);
(path, task)
})?;
save_task.await?;
if let Some(path) = path {
writeln!(&mut answer, "{}", path)?;
}
} }
project.save_buffer(buffer, cx) let errors = self.parser.errors();
})?
.await?;
}
let errors = parser.errors(); if errors.is_empty() && self.bad_searches.is_empty() {
if errors.is_empty() {
Ok(answer.trim_end().to_string()) Ok(answer.trim_end().to_string())
} else { } else {
let error_message = errors if !self.bad_searches.is_empty() {
.iter() writeln!(
.map(|e| e.to_string()) &mut answer,
.collect::<Vec<_>>() "\nThese searches failed because they didn't match any strings:"
.join("\n"); )?;
if applied_edits > 0 { for replace in self.bad_searches {
Err(anyhow!( writeln!(
"Applied {} edit(s), but some blocks failed to parse:\n{}", &mut answer,
applied_edits, "- '{}' does not appear in `{}`",
error_message replace.search.replace("\r", "\\r").replace("\n", "\\n"),
)) replace.file_path
} else { )?;
Err(anyhow!(error_message)) }
writeln!(&mut answer, "Make sure to use exact searches.")?;
}
if !errors.is_empty() {
writeln!(
&mut answer,
"\nThese SEARCH/REPLACE blocks failed to parse:"
)?;
for error in errors {
writeln!(&mut answer, "- {}", error)?;
} }
} }
})
writeln!(
&mut answer,
"\nYou can fix errors by running the tool again. You can include instructions,\
but errors are part of the conversation so you don't need to repeat them."
)?;
Err(anyhow!(answer))
}
} }
} }

View file

@ -526,8 +526,8 @@ impl DerefMut for ChunkRendererContext<'_, '_> {
/// A set of edits to a given version of a buffer, computed asynchronously. /// A set of edits to a given version of a buffer, computed asynchronously.
#[derive(Debug)] #[derive(Debug)]
pub struct Diff { pub struct Diff {
pub(crate) base_version: clock::Global, pub base_version: clock::Global,
line_ending: LineEnding, pub line_ending: LineEnding,
pub edits: Vec<(Range<usize>, Arc<str>)>, pub edits: Vec<(Range<usize>, Arc<str>)>,
} }