assistant edit tool: Track read buffers and notify model of user edits (#26952)

When the model reads file, we'll track the version it read, and let it
know if the user makes edits to the buffer. This helps prevent edit
failures because it'll know to re-read the file before.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-03-17 18:50:16 -03:00 committed by GitHub
parent cb439e672d
commit a05066cd83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 83 additions and 23 deletions

1
Cargo.lock generated
View file

@ -692,6 +692,7 @@ name = "assistant_tool"
version = "0.1.0"
dependencies = [
"anyhow",
"clock",
"collections",
"derive_more",
"gpui",

View file

@ -361,7 +361,7 @@ impl ActiveThread {
if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_pending_refresh_buffers()
action_log.take_stale_buffers_in_context()
})
});

View file

@ -1,3 +1,4 @@
use std::fmt::Write as _;
use std::io::Write;
use std::sync::Arc;
@ -560,9 +561,39 @@ impl Thread {
request.messages.push(context_message);
}
self.attach_stale_files(&mut request.messages, cx);
request
}
fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
for stale_file in self.action_log.read(cx).stale_buffers(cx) {
let Some(file) = stale_file.read(cx).file() else {
continue;
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
}
if !stale_message.is_empty() {
let context_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![stale_message.into()],
cache: false,
};
messages.push(context_message);
}
}
pub fn stream_completion(
&mut self,
request: LanguageModelRequest,

View file

@ -14,6 +14,7 @@ path = "src/assistant_tool.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
clock.workspace = true
derive_more.workspace = true
gpui.workspace = true
language.workspace = true

View file

@ -4,7 +4,7 @@ mod tool_working_set;
use std::sync::Arc;
use anyhow::Result;
use collections::HashSet;
use collections::{HashMap, HashSet};
use gpui::Context;
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
@ -58,31 +58,53 @@ pub trait Tool: 'static + Send + Sync {
/// Tracks actions performed by tools in a thread
#[derive(Debug)]
pub struct ActionLog {
changed_buffers: HashSet<Entity<Buffer>>,
pending_refresh: HashSet<Entity<Buffer>>,
/// Buffers that user manually added to the context, and whose content has
/// changed since the model last saw them.
stale_buffers_in_context: HashSet<Entity<Buffer>>,
/// Buffers that we want to notify the model about when they change.
tracked_buffers: HashMap<Entity<Buffer>, TrackedBuffer>,
}
#[derive(Debug, Default)]
struct TrackedBuffer {
version: clock::Global,
}
impl ActionLog {
/// Creates a new, empty action log.
pub fn new() -> Self {
Self {
changed_buffers: HashSet::default(),
pending_refresh: HashSet::default(),
stale_buffers_in_context: HashSet::default(),
tracked_buffers: HashMap::default(),
}
}
/// Registers buffers that have changed and need refreshing.
pub fn notify_buffers_changed(
&mut self,
buffers: HashSet<Entity<Buffer>>,
_cx: &mut Context<Self>,
) {
self.changed_buffers.extend(buffers.clone());
self.pending_refresh.extend(buffers);
/// Track a buffer as read, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffers: HashSet<Entity<Buffer>>, cx: &mut Context<Self>) {
for buffer in &buffers {
let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default();
tracked_buffer.version = buffer.read(cx).version();
}
self.stale_buffers_in_context.extend(buffers);
}
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
.iter()
.filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version)
.map(|(buffer, _)| buffer)
}
/// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_pending_refresh_buffers(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.pending_refresh)
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.stale_buffers_in_context)
}
}

View file

@ -309,9 +309,7 @@ impl EditToolRequest {
}
self.action_log
.update(cx, |log, cx| {
log.notify_buffers_changed(self.changed_buffers, cx)
})
.update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
.log_err();
let errors = self.parser.errors();

View file

@ -49,7 +49,7 @@ impl Tool for ReadFileTool {
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
@ -60,14 +60,15 @@ impl Tool for ReadFileTool {
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path not found in project")));
};
cx.spawn(|cx| async move {
cx.spawn(|mut cx| async move {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
buffer.read_with(&cx, |buffer, _cx| {
let result = buffer.read_with(&cx, |buffer, _cx| {
if buffer
.file()
.map_or(false, |file| file.disk_state().exists())
@ -76,7 +77,13 @@ impl Tool for ReadFileTool {
} else {
Err(anyhow!("File does not exist"))
}
})?
})??;
action_log.update(&mut cx, |log, cx| {
log.buffer_read(buffer, cx);
})?;
anyhow::Ok(result)
})
}
}