diff --git a/Cargo.lock b/Cargo.lock index 36182efd3d..6b0b113fdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -692,6 +692,7 @@ name = "assistant_tool" version = "0.1.0" dependencies = [ "anyhow", + "clock", "collections", "derive_more", "gpui", diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 3ab8e0fabf..a9797166f3 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -361,7 +361,7 @@ impl ActiveThread { if self.thread.read(cx).all_tools_finished() { let pending_refresh_buffers = self.thread.update(cx, |thread, cx| { thread.action_log().update(cx, |action_log, _cx| { - action_log.take_pending_refresh_buffers() + action_log.take_stale_buffers_in_context() }) }); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index ddd6732446..9e97a9d7c4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,3 +1,4 @@ +use std::fmt::Write as _; use std::io::Write; use std::sync::Arc; @@ -560,9 +561,39 @@ impl Thread { request.messages.push(context_message); } + self.attach_stale_files(&mut request.messages, cx); + request } + fn attach_stale_files(&self, messages: &mut Vec, cx: &App) { + const STALE_FILES_HEADER: &str = "These files changed since last read:"; + + let mut stale_message = String::new(); + + for stale_file in self.action_log.read(cx).stale_buffers(cx) { + let Some(file) = stale_file.read(cx).file() else { + continue; + }; + + if stale_message.is_empty() { + write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok(); + } + + writeln!(&mut stale_message, "- {}", file.path().display()).ok(); + } + + if !stale_message.is_empty() { + let context_message = LanguageModelRequestMessage { + role: Role::User, + content: vec![stale_message.into()], + cache: false, + }; + + messages.push(context_message); + } + } + pub fn stream_completion( &mut self, request: LanguageModelRequest, diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml index 70022fc02c..040a906bf3 100644 --- a/crates/assistant_tool/Cargo.toml +++ b/crates/assistant_tool/Cargo.toml @@ -14,6 +14,7 @@ path = "src/assistant_tool.rs" [dependencies] anyhow.workspace = true collections.workspace = true +clock.workspace = true derive_more.workspace = true gpui.workspace = true language.workspace = true diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index b466931d89..22564bc37f 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -4,7 +4,7 @@ mod tool_working_set; use std::sync::Arc; use anyhow::Result; -use collections::HashSet; +use collections::{HashMap, HashSet}; use gpui::Context; use gpui::{App, Entity, SharedString, Task}; use language::Buffer; @@ -58,31 +58,53 @@ pub trait Tool: 'static + Send + Sync { /// Tracks actions performed by tools in a thread #[derive(Debug)] pub struct ActionLog { - changed_buffers: HashSet>, - pending_refresh: HashSet>, + /// Buffers that user manually added to the context, and whose content has + /// changed since the model last saw them. + stale_buffers_in_context: HashSet>, + /// Buffers that we want to notify the model about when they change. + tracked_buffers: HashMap, TrackedBuffer>, +} + +#[derive(Debug, Default)] +struct TrackedBuffer { + version: clock::Global, } impl ActionLog { /// Creates a new, empty action log. pub fn new() -> Self { Self { - changed_buffers: HashSet::default(), - pending_refresh: HashSet::default(), + stale_buffers_in_context: HashSet::default(), + tracked_buffers: HashMap::default(), } } - /// Registers buffers that have changed and need refreshing. - pub fn notify_buffers_changed( - &mut self, - buffers: HashSet>, - _cx: &mut Context, - ) { - self.changed_buffers.extend(buffers.clone()); - self.pending_refresh.extend(buffers); + /// Track a buffer as read, so we can notify the model about user edits. + pub fn buffer_read(&mut self, buffer: Entity, cx: &mut Context) { + let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); + tracked_buffer.version = buffer.read(cx).version(); + } + + /// Mark a buffer as edited, so we can refresh it in the context + pub fn buffer_edited(&mut self, buffers: HashSet>, cx: &mut Context) { + for buffer in &buffers { + let tracked_buffer = self.tracked_buffers.entry(buffer.clone()).or_default(); + tracked_buffer.version = buffer.read(cx).version(); + } + + self.stale_buffers_in_context.extend(buffers); + } + + /// Iterate over buffers changed since last read or edited by the model + pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator> { + self.tracked_buffers + .iter() + .filter(|(buffer, tracked)| tracked.version != buffer.read(cx).version) + .map(|(buffer, _)| buffer) } /// Takes and returns the set of buffers pending refresh, clearing internal state. - pub fn take_pending_refresh_buffers(&mut self) -> HashSet> { - std::mem::take(&mut self.pending_refresh) + pub fn take_stale_buffers_in_context(&mut self) -> HashSet> { + std::mem::take(&mut self.stale_buffers_in_context) } } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 5cc81eda37..10a2454c3d 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -309,9 +309,7 @@ impl EditToolRequest { } self.action_log - .update(cx, |log, cx| { - log.notify_buffers_changed(self.changed_buffers, cx) - }) + .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx)) .log_err(); let errors = self.parser.errors(); diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index d4c69df8e5..e1f012be44 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -49,7 +49,7 @@ impl Tool for ReadFileTool { input: serde_json::Value, _messages: &[LanguageModelRequestMessage], project: Entity, - _action_log: Entity, + action_log: Entity, cx: &mut App, ) -> Task> { let input = match serde_json::from_value::(input) { @@ -60,14 +60,15 @@ impl Tool for ReadFileTool { let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path not found in project"))); }; - cx.spawn(|cx| async move { + + cx.spawn(|mut cx| async move { let buffer = cx .update(|cx| { project.update(cx, |project, cx| project.open_buffer(project_path, cx)) })? .await?; - buffer.read_with(&cx, |buffer, _cx| { + let result = buffer.read_with(&cx, |buffer, _cx| { if buffer .file() .map_or(false, |file| file.disk_state().exists()) @@ -76,7 +77,13 @@ impl Tool for ReadFileTool { } else { Err(anyhow!("File does not exist")) } - })? + })??; + + action_log.update(&mut cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + anyhow::Ok(result) }) } }