assistant edit tool: Track read buffers and notify model of user edits (#26952)

When the model reads file, we'll track the version it read, and let it
know if the user makes edits to the buffer. This helps prevent edit
failures because it'll know to re-read the file before.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-17 18:50:16 -03:00 committed by GitHub
parent cb439e672d
commit a05066cd83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 83 additions and 23 deletions

1
Cargo.lock generated
View file

@ -692,6 +692,7 @@ name = "assistant_tool"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clock",
"collections", "collections",
"derive_more", "derive_more",
"gpui", "gpui",

View file

@ -361,7 +361,7 @@ impl ActiveThread {
if self.thread.read(cx).all_tools_finished() { if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| { let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| { thread.action_log().update(cx, |action_log, _cx| {
action_log.take_pending_refresh_buffers() action_log.take_stale_buffers_in_context()
}) })
}); });

View file

@ -1,3 +1,4 @@
use std::fmt::Write as _;
use std::io::Write; use std::io::Write;
use std::sync::Arc; use std::sync::Arc;
@ -560,9 +561,39 @@ impl Thread {
request.messages.push(context_message); request.messages.push(context_message);
} }
self.attach_stale_files(&mut request.messages, cx);
request request
} }
fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
for stale_file in self.action_log.read(cx).stale_buffers(cx) {
let Some(file) = stale_file.read(cx).file() else {
continue;
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
}
if !stale_message.is_empty() {
let context_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![stale_message.into()],
cache: false,
};
messages.push(context_message);
}
}
pub fn stream_completion( pub fn stream_completion(
&mut self, &mut self,
request: LanguageModelRequest, request: LanguageModelRequest,

View file

@ -14,6 +14,7 @@ path = "src/assistant_tool.rs"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
collections.workspace = true collections.workspace = true
clock.workspace = true
derive_more.workspace = true derive_more.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true language.workspace = true

View file

@ -4,7 +4,7 @@ mod tool_working_set;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use collections::HashSet; use collections::{HashMap, HashSet};
use gpui::Context; use gpui::Context;
use gpui::{App, Entity, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language::Buffer; use language::Buffer;
@ -58,31 +58,53 @@ pub trait Tool: 'static + Send + Sync {
/// Tracks actions performed by tools in a thread /// Tracks actions performed by tools in a thread
#[derive(Debug)] #[derive(Debug)]
pub struct ActionLog { pub struct ActionLog {
changed_buffers: HashSet<Entity<Buffer>>, /// Buffers that user manually added to the context, and whose content has
pending_refresh: HashSet<Entity<Buffer>>, /// changed since the model last saw them.
stale_buffers_in_context: HashSet<Entity<Buffer>>,
/// Buffers that we want to notify the model about when they change.
tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
}
#[derive(Debug, Default)]
struct TrackedBuffer {
version: clock::Global,
} }
impl ActionLog { impl ActionLog {
/// Creates a new, empty action log. /// Creates a new, empty action log.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
changed_buffers: HashSet::default(), stale_buffers_in_context: HashSet::default(),
pending_refresh: HashSet::default(), tracked_buffers: HashMap::default(),
} }
} }
/// Registers buffers that have changed and need refreshing. /// Track a buffer as read, so we can notify the model about user edits.
pub fn notify_buffers_changed( pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
&mut self, let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
buffers: HashSet<Entity<Buffer>>, tracked_buffer.version = buffer.read(cx).version();
_cx: &mut Context<Self>, }
) {
self.changed_buffers.extend(buffers.clone()); /// Mark a buffer as edited, so we can refresh it in the context
self.pending_refresh.extend(buffers); pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
for buffer in &buffers {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
self.stale_buffers_in_context.extend(buffers);
}
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
.iter()
.filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
.map(|(buffer, _)| buffer)
} }
/// Takes and returns the set of buffers pending refresh, clearing internal state. /// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_pending_refresh_buffers(&mut self) -> HashSet<Entity<Buffer>> { pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.pending_refresh) std::mem::take(&mut self.stale_buffers_in_context)
} }
} }

View file

@ -309,9 +309,7 @@ impl EditToolRequest {
} }
self.action_log self.action_log
.update(cx, |log, cx| { .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
log.notify_buffers_changed(self.changed_buffers, cx)
})
.log_err(); .log_err();
let errors = self.parser.errors(); let errors = self.parser.errors();

View file

@ -49,7 +49,7 @@ impl Tool for ReadFileTool {
input: serde_json::Value, input: serde_json::Value,
_messages: &[LanguageModelRequestMessage], _messages: &[LanguageModelRequestMessage],
project: Entity<Project>, project: Entity<Project>,
_action_log: Entity<ActionLog>, action_log: Entity<ActionLog>,
cx: &mut App, cx: &mut App,
) -> Task<Result<String>> { ) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) { let input = match serde_json::from_value::<ReadFileToolInput>(input) {
@ -60,14 +60,15 @@ impl Tool for ReadFileTool {
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path not found in project"))); return Task::ready(Err(anyhow!("Path not found in project")));
}; };
cx.spawn(|cx| async move {
cx.spawn(|mut cx| async move {
let buffer = cx let buffer = cx
.update(|cx| { .update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx)) project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})? })?
.await?; .await?;
buffer.read_with(&cx, |buffer, _cx| { let result = buffer.read_with(&cx, |buffer, _cx| {
if buffer if buffer
.file() .file()
.map_or(false, |file| file.disk_state().exists()) .map_or(false, |file| file.disk_state().exists())
@ -76,7 +77,13 @@ impl Tool for ReadFileTool {
} else { } else {
Err(anyhow!("File does not exist")) Err(anyhow!("File does not exist"))
} }
})? })??;
action_log.update(&mut cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
anyhow::Ok(result)
}) })
} }
} }