tools: Send stale file notifications only once (#34026)
Previously, we sent notifications repeatedly until the agent read a file, which was often inefficient. With this change, we now send a notification only once (unless the files are modified again, in which case we'll send another notification). Release Notes: - N/A
This commit is contained in:
parent
e0c860c42a
commit
d549993c73
3 changed files with 117 additions and 19 deletions
|
@ -1516,7 +1516,7 @@ impl Thread {
|
||||||
) -> Option<PendingToolUse> {
|
) -> Option<PendingToolUse> {
|
||||||
let action_log = self.action_log.read(cx);
|
let action_log = self.action_log.read(cx);
|
||||||
|
|
||||||
action_log.stale_buffers(cx).next()?;
|
action_log.unnotified_stale_buffers(cx).next()?;
|
||||||
|
|
||||||
// Represent notification as a simulated `project_notifications` tool call
|
// Represent notification as a simulated `project_notifications` tool call
|
||||||
let tool_name = Arc::from("project_notifications");
|
let tool_name = Arc::from("project_notifications");
|
||||||
|
@ -3631,11 +3631,11 @@ fn main() {{
|
||||||
});
|
});
|
||||||
|
|
||||||
// We shouldn't have a stale buffer notification yet
|
// We shouldn't have a stale buffer notification yet
|
||||||
let notification = thread.read_with(cx, |thread, _| {
|
let notifications = thread.read_with(cx, |thread, _| {
|
||||||
find_tool_use(thread, "project_notifications")
|
find_tool_uses(thread, "project_notifications")
|
||||||
});
|
});
|
||||||
assert!(
|
assert!(
|
||||||
notification.is_none(),
|
notifications.is_empty(),
|
||||||
"Should not have stale buffer notification before buffer is modified"
|
"Should not have stale buffer notification before buffer is modified"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -3664,13 +3664,15 @@ fn main() {{
|
||||||
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
|
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
|
let notifications = thread.read_with(cx, |thread, _cx| {
|
||||||
find_tool_use(thread, "project_notifications")
|
find_tool_uses(thread, "project_notifications")
|
||||||
}) else {
|
});
|
||||||
|
|
||||||
|
let [notification] = notifications.as_slice() else {
|
||||||
panic!("Should have a `project_notifications` tool use");
|
panic!("Should have a `project_notifications` tool use");
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(notification_content) = notification_result.content.to_str() else {
|
let Some(notification_content) = notification.content.to_str() else {
|
||||||
panic!("`project_notifications` should return text");
|
panic!("`project_notifications` should return text");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3680,19 +3682,46 @@ fn main() {{
|
||||||
- code.rs
|
- code.rs
|
||||||
"};
|
"};
|
||||||
assert_eq!(notification_content, expected_content);
|
assert_eq!(notification_content, expected_content);
|
||||||
|
|
||||||
|
// Insert another user message and flush notifications again
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.insert_user_message(
|
||||||
|
"Can you tell me more?",
|
||||||
|
ContextLoadResult::default(),
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
thread.update(cx, |thread, cx| {
|
||||||
|
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
// There should be no new notifications (we already flushed one)
|
||||||
|
let notifications = thread.read_with(cx, |thread, _cx| {
|
||||||
|
find_tool_uses(thread, "project_notifications")
|
||||||
|
});
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
notifications.len(),
|
||||||
|
1,
|
||||||
|
"Should still have only one notification after second flush - no duplicates"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
|
fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec<LanguageModelToolResult> {
|
||||||
thread
|
thread
|
||||||
.messages()
|
.messages()
|
||||||
.filter_map(|message| {
|
.flat_map(|message| {
|
||||||
thread
|
thread
|
||||||
.tool_results_for_message(message.id)
|
.tool_results_for_message(message.id)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find(|result| result.tool_name == tool_name.into())
|
.filter(|result| result.tool_name == tool_name.into())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
.next()
|
.collect()
|
||||||
.cloned()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result};
|
||||||
use buffer_diff::BufferDiff;
|
use buffer_diff::BufferDiff;
|
||||||
|
use clock;
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use futures::{FutureExt, StreamExt, channel::mpsc};
|
use futures::{FutureExt, StreamExt, channel::mpsc};
|
||||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
|
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
|
||||||
|
@ -17,6 +18,8 @@ pub struct ActionLog {
|
||||||
edited_since_project_diagnostics_check: bool,
|
edited_since_project_diagnostics_check: bool,
|
||||||
/// The project this action log is associated with
|
/// The project this action log is associated with
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
|
/// Tracks which buffer versions have already been notified as changed externally
|
||||||
|
notified_versions: BTreeMap<Entity<Buffer>, clock::Global>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ActionLog {
|
impl ActionLog {
|
||||||
|
@ -26,6 +29,7 @@ impl ActionLog {
|
||||||
tracked_buffers: BTreeMap::default(),
|
tracked_buffers: BTreeMap::default(),
|
||||||
edited_since_project_diagnostics_check: false,
|
edited_since_project_diagnostics_check: false,
|
||||||
project,
|
project,
|
||||||
|
notified_versions: BTreeMap::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +55,7 @@ impl ActionLog {
|
||||||
) -> &mut TrackedBuffer {
|
) -> &mut TrackedBuffer {
|
||||||
let status = if is_created {
|
let status = if is_created {
|
||||||
if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
|
if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
match tracked.status {
|
match tracked.status {
|
||||||
TrackedBufferStatus::Created {
|
TrackedBufferStatus::Created {
|
||||||
existing_file_content,
|
existing_file_content,
|
||||||
|
@ -106,7 +111,7 @@ impl ActionLog {
|
||||||
TrackedBuffer {
|
TrackedBuffer {
|
||||||
buffer: buffer.clone(),
|
buffer: buffer.clone(),
|
||||||
diff_base,
|
diff_base,
|
||||||
unreviewed_edits: unreviewed_edits,
|
unreviewed_edits,
|
||||||
snapshot: text_snapshot.clone(),
|
snapshot: text_snapshot.clone(),
|
||||||
status,
|
status,
|
||||||
version: buffer.read(cx).version(),
|
version: buffer.read(cx).version(),
|
||||||
|
@ -165,6 +170,7 @@ impl ActionLog {
|
||||||
// If the buffer had been edited by a tool, but it got
|
// If the buffer had been edited by a tool, but it got
|
||||||
// deleted externally, we want to stop tracking it.
|
// deleted externally, we want to stop tracking it.
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
@ -178,6 +184,7 @@ impl ActionLog {
|
||||||
// resurrected externally, we want to clear the edits we
|
// resurrected externally, we want to clear the edits we
|
||||||
// were tracking and reset the buffer's state.
|
// were tracking and reset the buffer's state.
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
self.track_buffer_internal(buffer, false, cx);
|
self.track_buffer_internal(buffer, false, cx);
|
||||||
}
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
@ -483,6 +490,7 @@ impl ActionLog {
|
||||||
match tracked_buffer.status {
|
match tracked_buffer.status {
|
||||||
TrackedBufferStatus::Created { .. } => {
|
TrackedBufferStatus::Created { .. } => {
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
TrackedBufferStatus::Modified => {
|
TrackedBufferStatus::Modified => {
|
||||||
|
@ -508,6 +516,7 @@ impl ActionLog {
|
||||||
match tracked_buffer.status {
|
match tracked_buffer.status {
|
||||||
TrackedBufferStatus::Deleted => {
|
TrackedBufferStatus::Deleted => {
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -616,6 +625,7 @@ impl ActionLog {
|
||||||
};
|
};
|
||||||
|
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
task
|
task
|
||||||
}
|
}
|
||||||
|
@ -629,6 +639,7 @@ impl ActionLog {
|
||||||
|
|
||||||
// Clear all tracked edits for this buffer and start over as if we just read it.
|
// Clear all tracked edits for this buffer and start over as if we just read it.
|
||||||
self.tracked_buffers.remove(&buffer);
|
self.tracked_buffers.remove(&buffer);
|
||||||
|
self.notified_versions.remove(&buffer);
|
||||||
self.buffer_read(buffer.clone(), cx);
|
self.buffer_read(buffer.clone(), cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
save
|
save
|
||||||
|
@ -713,6 +724,33 @@ impl ActionLog {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns stale buffers that haven't been notified yet
|
||||||
|
pub fn unnotified_stale_buffers<'a>(
|
||||||
|
&'a self,
|
||||||
|
cx: &'a App,
|
||||||
|
) -> impl Iterator<Item = &'a Entity<Buffer>> {
|
||||||
|
self.stale_buffers(cx).filter(|buffer| {
|
||||||
|
let buffer_entity = buffer.read(cx);
|
||||||
|
self.notified_versions
|
||||||
|
.get(buffer)
|
||||||
|
.map_or(true, |notified_version| {
|
||||||
|
*notified_version != buffer_entity.version
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Marks the given buffers as notified at their current versions
|
||||||
|
pub fn mark_buffers_as_notified(
|
||||||
|
&mut self,
|
||||||
|
buffers: impl IntoIterator<Item = Entity<Buffer>>,
|
||||||
|
cx: &App,
|
||||||
|
) {
|
||||||
|
for buffer in buffers {
|
||||||
|
let version = buffer.read(cx).version.clone();
|
||||||
|
self.notified_versions.insert(buffer, version);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Iterate over buffers changed since last read or edited by the model
|
/// Iterate over buffers changed since last read or edited by the model
|
||||||
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
|
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
|
||||||
self.tracked_buffers
|
self.tracked_buffers
|
||||||
|
|
|
@ -53,15 +53,21 @@ impl Tool for ProjectNotificationsTool {
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
) -> ToolResult {
|
) -> ToolResult {
|
||||||
let mut stale_files = String::new();
|
let mut stale_files = String::new();
|
||||||
|
let mut notified_buffers = Vec::new();
|
||||||
|
|
||||||
let action_log = action_log.read(cx);
|
for stale_file in action_log.read(cx).unnotified_stale_buffers(cx) {
|
||||||
|
|
||||||
for stale_file in action_log.stale_buffers(cx) {
|
|
||||||
if let Some(file) = stale_file.read(cx).file() {
|
if let Some(file) = stale_file.read(cx).file() {
|
||||||
writeln!(&mut stale_files, "- {}", file.path().display()).ok();
|
writeln!(&mut stale_files, "- {}", file.path().display()).ok();
|
||||||
|
notified_buffers.push(stale_file.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !notified_buffers.is_empty() {
|
||||||
|
action_log.update(cx, |log, cx| {
|
||||||
|
log.mark_buffers_as_notified(notified_buffers, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let response = if stale_files.is_empty() {
|
let response = if stale_files.is_empty() {
|
||||||
"No new notifications".to_string()
|
"No new notifications".to_string()
|
||||||
} else {
|
} else {
|
||||||
|
@ -155,11 +161,11 @@ mod tests {
|
||||||
|
|
||||||
// Run the tool again
|
// Run the tool again
|
||||||
let result = cx.update(|cx| {
|
let result = cx.update(|cx| {
|
||||||
tool.run(
|
tool.clone().run(
|
||||||
tool_input.clone(),
|
tool_input.clone(),
|
||||||
request.clone(),
|
request.clone(),
|
||||||
project.clone(),
|
project.clone(),
|
||||||
action_log,
|
action_log.clone(),
|
||||||
model.clone(),
|
model.clone(),
|
||||||
None,
|
None,
|
||||||
cx,
|
cx,
|
||||||
|
@ -179,6 +185,31 @@ mod tests {
|
||||||
expected_content,
|
expected_content,
|
||||||
"Tool should return the stale buffer notification"
|
"Tool should return the stale buffer notification"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Run the tool once more without any changes - should get no new notifications
|
||||||
|
let result = cx.update(|cx| {
|
||||||
|
tool.run(
|
||||||
|
tool_input.clone(),
|
||||||
|
request.clone(),
|
||||||
|
project.clone(),
|
||||||
|
action_log,
|
||||||
|
model.clone(),
|
||||||
|
None,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = result.output.await.unwrap();
|
||||||
|
let response_text = match &response.content {
|
||||||
|
ToolResultContent::Text(text) => text.clone(),
|
||||||
|
_ => panic!("Expected text response"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
response_text.as_str(),
|
||||||
|
"No new notifications",
|
||||||
|
"Tool should return 'No new notifications' when running again without changes"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn init_test(cx: &mut TestAppContext) {
|
fn init_test(cx: &mut TestAppContext) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue