assistant edit tool: Use buffer search and replace in background (#26679)
Instead of getting the whole text from the buffer, replacing with `String::replace`, and getting a whole diff, we'll now use `SearchQuery` to get a range, diff only that range, and apply it (all in the background). When we match zero strings, we'll record a "bad search", keep going and report it to the model at the end. Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
6767e98e00
commit
8ec0309645
2 changed files with 219 additions and 111 deletions
|
@ -6,16 +6,17 @@ use assistant_tool::Tool;
|
||||||
use collections::HashSet;
|
use collections::HashSet;
|
||||||
use edit_action::{EditAction, EditActionParser};
|
use edit_action::{EditAction, EditActionParser};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use gpui::{App, Entity, Task};
|
use gpui::{App, AsyncApp, Entity, Task};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
|
||||||
};
|
};
|
||||||
use log::{EditToolLog, EditToolRequestId};
|
use log::{EditToolLog, EditToolRequestId};
|
||||||
use project::Project;
|
use project::{search::SearchQuery, Project};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use util::paths::PathMatcher;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
@ -93,7 +94,7 @@ impl Tool for EditFilesTool {
|
||||||
});
|
});
|
||||||
|
|
||||||
let task =
|
let task =
|
||||||
EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx);
|
EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
|
||||||
|
|
||||||
cx.spawn(|mut cx| async move {
|
cx.spawn(|mut cx| async move {
|
||||||
let result = task.await;
|
let result = task.await;
|
||||||
|
@ -112,13 +113,33 @@ impl Tool for EditFilesTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
None => EditFilesTool::run(input, messages, project, None, cx),
|
None => EditToolRequest::new(input, messages, project, None, cx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditFilesTool {
|
struct EditToolRequest {
|
||||||
fn run(
|
parser: EditActionParser,
|
||||||
|
changed_buffers: HashSet<Entity<language::Buffer>>,
|
||||||
|
bad_searches: Vec<BadSearch>,
|
||||||
|
project: Entity<Project>,
|
||||||
|
log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum DiffResult {
|
||||||
|
BadSearch(BadSearch),
|
||||||
|
Diff(language::Diff),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct BadSearch {
|
||||||
|
file_path: String,
|
||||||
|
search: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EditToolRequest {
|
||||||
|
fn new(
|
||||||
input: EditFilesToolInput,
|
input: EditFilesToolInput,
|
||||||
messages: &[LanguageModelRequestMessage],
|
messages: &[LanguageModelRequestMessage],
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
@ -147,121 +168,208 @@ impl EditFilesTool {
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.spawn(|mut cx| async move {
|
cx.spawn(|mut cx| async move {
|
||||||
let request = LanguageModelRequest {
|
let llm_request = LanguageModelRequest {
|
||||||
messages,
|
messages,
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut parser = EditActionParser::new();
|
let stream = model.stream_completion_text(llm_request, &cx);
|
||||||
|
|
||||||
let stream = model.stream_completion_text(request, &cx);
|
|
||||||
let mut chunks = stream.await?;
|
let mut chunks = stream.await?;
|
||||||
|
|
||||||
let mut changed_buffers = HashSet::default();
|
let mut request = Self {
|
||||||
let mut applied_edits = 0;
|
parser: EditActionParser::new(),
|
||||||
|
changed_buffers: HashSet::default(),
|
||||||
let log = log.clone();
|
bad_searches: Vec::new(),
|
||||||
|
project,
|
||||||
|
log,
|
||||||
|
};
|
||||||
|
|
||||||
while let Some(chunk) = chunks.stream.next().await {
|
while let Some(chunk) = chunks.stream.next().await {
|
||||||
let chunk = chunk?;
|
request.process_response_chunk(&chunk?, &mut cx).await?;
|
||||||
|
}
|
||||||
|
|
||||||
let new_actions = parser.parse_chunk(&chunk);
|
request.finalize(&mut cx).await
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if let Some((ref log, req_id)) = log {
|
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
|
||||||
log.update(&mut cx, |log, cx| {
|
let new_actions = self.parser.parse_chunk(chunk);
|
||||||
log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
|
|
||||||
|
if let Some((ref log, req_id)) = self.log {
|
||||||
|
log.update(cx, |log, cx| {
|
||||||
|
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
|
||||||
})
|
})
|
||||||
.log_err();
|
.log_err();
|
||||||
}
|
}
|
||||||
|
|
||||||
for action in new_actions {
|
for action in new_actions {
|
||||||
let project_path = project.read_with(&cx, |project, cx| {
|
self.apply_action(action, cx).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
|
||||||
|
let project_path = self.project.read_with(cx, |project, cx| {
|
||||||
project
|
project
|
||||||
.find_project_path(action.file_path(), cx)
|
.find_project_path(action.file_path(), cx)
|
||||||
.context("Path not found in project")
|
.context("Path not found in project")
|
||||||
})??;
|
})??;
|
||||||
|
|
||||||
let buffer = project
|
let buffer = self
|
||||||
.update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
|
.project
|
||||||
|
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let diff = buffer
|
let result = match action {
|
||||||
.read_with(&cx, |buffer, cx| {
|
|
||||||
let new_text = match action {
|
|
||||||
EditAction::Replace {
|
EditAction::Replace {
|
||||||
file_path,
|
|
||||||
old,
|
old,
|
||||||
new,
|
new,
|
||||||
|
file_path,
|
||||||
} => {
|
} => {
|
||||||
// TODO: Replace in background?
|
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||||
let text = buffer.text();
|
|
||||||
if text.contains(&old) {
|
cx.background_executor()
|
||||||
text.replace(&old, &new)
|
.spawn(Self::replace_diff(old, new, file_path, snapshot))
|
||||||
} else {
|
.await
|
||||||
return Err(anyhow!(
|
}
|
||||||
"Could not find search text in {}",
|
EditAction::Write { content, .. } => Ok(DiffResult::Diff(
|
||||||
file_path.display()
|
buffer
|
||||||
));
|
.read_with(cx, |buffer, cx| buffer.diff(content, cx))?
|
||||||
|
.await,
|
||||||
|
)),
|
||||||
|
}?;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
DiffResult::BadSearch(invalid_replace) => {
|
||||||
|
self.bad_searches.push(invalid_replace);
|
||||||
|
}
|
||||||
|
DiffResult::Diff(diff) => {
|
||||||
|
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
|
||||||
|
|
||||||
|
self.changed_buffers.insert(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EditAction::Write { content, .. } => content,
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn replace_diff(
|
||||||
|
old: String,
|
||||||
|
new: String,
|
||||||
|
file_path: std::path::PathBuf,
|
||||||
|
snapshot: language::BufferSnapshot,
|
||||||
|
) -> Result<DiffResult> {
|
||||||
|
let query = SearchQuery::text(
|
||||||
|
old.clone(),
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
PathMatcher::new(&[])?,
|
||||||
|
PathMatcher::new(&[])?,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let matches = query.search(&snapshot, None).await;
|
||||||
|
|
||||||
|
if matches.is_empty() {
|
||||||
|
return Ok(DiffResult::BadSearch(BadSearch {
|
||||||
|
search: new.clone(),
|
||||||
|
file_path: file_path.display().to_string(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let edit_range = matches[0].clone();
|
||||||
|
let diff = language::text_diff(&old, &new);
|
||||||
|
|
||||||
|
let edits = diff
|
||||||
|
.into_iter()
|
||||||
|
.map(|(old_range, text)| {
|
||||||
|
let start = edit_range.start + old_range.start;
|
||||||
|
let end = edit_range.start + old_range.end;
|
||||||
|
(start..end, text)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let diff = language::Diff {
|
||||||
|
base_version: snapshot.version().clone(),
|
||||||
|
line_ending: snapshot.line_ending(),
|
||||||
|
edits,
|
||||||
};
|
};
|
||||||
|
|
||||||
anyhow::Ok(buffer.diff(new_text, cx))
|
anyhow::Ok(DiffResult::Diff(diff))
|
||||||
})??
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let _clock =
|
|
||||||
buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
|
|
||||||
|
|
||||||
changed_buffers.insert(buffer);
|
|
||||||
|
|
||||||
applied_edits += 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut answer = match changed_buffers.len() {
|
async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
|
||||||
|
let mut answer = match self.changed_buffers.len() {
|
||||||
0 => "No files were edited.".to_string(),
|
0 => "No files were edited.".to_string(),
|
||||||
1 => "Successfully edited ".to_string(),
|
1 => "Successfully edited ".to_string(),
|
||||||
_ => "Successfully edited these files:\n\n".to_string(),
|
_ => "Successfully edited these files:\n\n".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Save each buffer once at the end
|
// Save each buffer once at the end
|
||||||
for buffer in changed_buffers {
|
for buffer in self.changed_buffers {
|
||||||
project
|
let (path, save_task) = self.project.update(cx, |project, cx| {
|
||||||
.update(&mut cx, |project, cx| {
|
let path = buffer
|
||||||
if let Some(file) = buffer.read(&cx).file() {
|
.read(cx)
|
||||||
let _ = writeln!(&mut answer, "{}", &file.full_path(cx).display());
|
.file()
|
||||||
|
.map(|file| file.path().display().to_string());
|
||||||
|
|
||||||
|
let task = project.save_buffer(buffer.clone(), cx);
|
||||||
|
|
||||||
|
(path, task)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
save_task.await?;
|
||||||
|
|
||||||
|
if let Some(path) = path {
|
||||||
|
writeln!(&mut answer, "{}", path)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
project.save_buffer(buffer, cx)
|
let errors = self.parser.errors();
|
||||||
})?
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let errors = parser.errors();
|
if errors.is_empty() && self.bad_searches.is_empty() {
|
||||||
|
|
||||||
if errors.is_empty() {
|
|
||||||
Ok(answer.trim_end().to_string())
|
Ok(answer.trim_end().to_string())
|
||||||
} else {
|
} else {
|
||||||
let error_message = errors
|
if !self.bad_searches.is_empty() {
|
||||||
.iter()
|
writeln!(
|
||||||
.map(|e| e.to_string())
|
&mut answer,
|
||||||
.collect::<Vec<_>>()
|
"\nThese searches failed because they didn't match any strings:"
|
||||||
.join("\n");
|
)?;
|
||||||
|
|
||||||
if applied_edits > 0 {
|
for replace in self.bad_searches {
|
||||||
Err(anyhow!(
|
writeln!(
|
||||||
"Applied {} edit(s), but some blocks failed to parse:\n{}",
|
&mut answer,
|
||||||
applied_edits,
|
"- '{}' does not appear in `{}`",
|
||||||
error_message
|
replace.search.replace("\r", "\\r").replace("\n", "\\n"),
|
||||||
))
|
replace.file_path
|
||||||
} else {
|
)?;
|
||||||
Err(anyhow!(error_message))
|
}
|
||||||
|
|
||||||
|
writeln!(&mut answer, "Make sure to use exact searches.")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.is_empty() {
|
||||||
|
writeln!(
|
||||||
|
&mut answer,
|
||||||
|
"\nThese SEARCH/REPLACE blocks failed to parse:"
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for error in errors {
|
||||||
|
writeln!(&mut answer, "- {}", error)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
writeln!(
|
||||||
|
&mut answer,
|
||||||
|
"\nYou can fix errors by running the tool again. You can include instructions,\
|
||||||
|
but errors are part of the conversation so you don't need to repeat them."
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Err(anyhow!(answer))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -526,8 +526,8 @@ impl DerefMut for ChunkRendererContext<'_, '_> {
|
||||||
/// A set of edits to a given version of a buffer, computed asynchronously.
|
/// A set of edits to a given version of a buffer, computed asynchronously.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Diff {
|
pub struct Diff {
|
||||||
pub(crate) base_version: clock::Global,
|
pub base_version: clock::Global,
|
||||||
line_ending: LineEnding,
|
pub line_ending: LineEnding,
|
||||||
pub edits: Vec<(Range<usize>, Arc<str>)>,
|
pub edits: Vec<(Range<usize>, Arc<str>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue