agent: Send stale file notifications using the project_notifications tool (#34005)
This commit introduces the `project_notifications` tool, which proactively pushes notifications to the agent. Unlike other tools, `Thread` automatically invokes this tool on every turn, even when the LLM doesn't ask for it. When notifications are available, the tool use and results are inserted into the thread, simulating an LLM tool call. As with other tools, users can disable `project_notifications` in Profiles if they do not want them. Currently, the tool only notifies users about stale files: that is, files that have been edited by the user while the agent is also working on them. In the future, notifications may be expanded to include compiler diagnostics, long-running processes, and more. Release Notes: - Added `project_notifications` tool
This commit is contained in:
parent
de9053c7ca
commit
d87603dd60
8 changed files with 427 additions and 4 deletions
|
@ -810,6 +810,7 @@
|
|||
"edit_file": true,
|
||||
"fetch": true,
|
||||
"list_directory": true,
|
||||
"project_notifications": true,
|
||||
"move_path": true,
|
||||
"now": true,
|
||||
"find_path": true,
|
||||
|
@ -829,6 +830,7 @@
|
|||
"diagnostics": true,
|
||||
"fetch": true,
|
||||
"list_directory": true,
|
||||
"project_notifications": true,
|
||||
"now": true,
|
||||
"find_path": true,
|
||||
"read_file": true,
|
||||
|
|
3
crates/agent/src/prompts/stale_files_prompt_header.txt
Normal file
3
crates/agent/src/prompts/stale_files_prompt_header.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
[The following is an auto-generated notification; do not reply]
|
||||
|
||||
These files have changed since the last read:
|
|
@ -25,8 +25,8 @@ use language_model::{
|
|||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
|
||||
Role, SelectedModel, StopReason, TokenUsage,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
|
||||
PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
|
||||
};
|
||||
use postage::stream::Stream as _;
|
||||
use project::{
|
||||
|
@ -45,7 +45,7 @@ use std::{
|
|||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use util::{ResultExt as _, debug_panic, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
|
@ -1248,6 +1248,8 @@ impl Thread {
|
|||
|
||||
self.remaining_turns -= 1;
|
||||
|
||||
self.flush_notifications(model.clone(), intent, cx);
|
||||
|
||||
let request = self.to_completion_request(model.clone(), intent, cx);
|
||||
|
||||
self.stream_completion(request, model, intent, window, cx);
|
||||
|
@ -1481,6 +1483,110 @@ impl Thread {
|
|||
request
|
||||
}
|
||||
|
||||
/// Insert auto-generated notifications (if any) to the thread
|
||||
fn flush_notifications(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match intent {
|
||||
CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
|
||||
if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
tool_use_id: pending_tool_use.id.clone(),
|
||||
pending_tool_use: Some(pending_tool_use),
|
||||
});
|
||||
}
|
||||
}
|
||||
CompletionIntent::ThreadSummarization
|
||||
| CompletionIntent::ThreadContextSummarization
|
||||
| CompletionIntent::CreateFile
|
||||
| CompletionIntent::EditFile
|
||||
| CompletionIntent::InlineAssist
|
||||
| CompletionIntent::TerminalInlineAssist
|
||||
| CompletionIntent::GenerateGitCommitMessage => {}
|
||||
};
|
||||
}
|
||||
|
||||
fn attach_tracked_files_state(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Option<PendingToolUse> {
|
||||
let action_log = self.action_log.read(cx);
|
||||
|
||||
action_log.stale_buffers(cx).next()?;
|
||||
|
||||
// Represent notification as a simulated `project_notifications` tool call
|
||||
let tool_name = Arc::from("project_notifications");
|
||||
let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
|
||||
debug_panic!("`project_notifications` tool not found");
|
||||
return None;
|
||||
};
|
||||
|
||||
if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let input = serde_json::json!({});
|
||||
let request = Arc::new(LanguageModelRequest::default()); // unused
|
||||
let window = None;
|
||||
let tool_result = tool.run(
|
||||
input,
|
||||
request,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
model.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let tool_use_id =
|
||||
LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: tool_use_id.clone(),
|
||||
name: tool_name.clone(),
|
||||
raw_input: "{}".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
is_input_complete: true,
|
||||
};
|
||||
|
||||
let tool_output = cx.background_executor().block(tool_result.output);
|
||||
|
||||
// Attach a project_notification tool call to the latest existing
|
||||
// Assistant message. We cannot create a new Assistant message
|
||||
// because thinking models require a `thinking` block that we
|
||||
// cannot mock. We cannot send a notification as a normal
|
||||
// (non-tool-use) User message because this distracts Agent
|
||||
// too much.
|
||||
let tool_message_id = self
|
||||
.messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rfind(|(_, message)| message.role == Role::Assistant)
|
||||
.map(|(_, message)| message.id)?;
|
||||
|
||||
let tool_use_metadata = ToolUseMetadata {
|
||||
model: model.clone(),
|
||||
thread_id: self.id.clone(),
|
||||
prompt_id: self.last_prompt_id.clone(),
|
||||
};
|
||||
|
||||
self.tool_use
|
||||
.request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
|
||||
|
||||
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
tool_output,
|
||||
self.configured_model.as_ref(),
|
||||
);
|
||||
|
||||
pending_tool_use
|
||||
}
|
||||
|
||||
pub fn stream_completion(
|
||||
&mut self,
|
||||
request: LanguageModelRequest,
|
||||
|
@ -3156,10 +3262,13 @@ mod tests {
|
|||
const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
|
||||
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use assistant_tools;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use gpui::TestAppContext;
|
||||
use http_client;
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use language_model::{
|
||||
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
|
||||
|
@ -3487,6 +3596,105 @@ fn main() {{
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, _thread_store, thread, context_store, model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Add a buffer to the context. This will be a tracked buffer
|
||||
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = context_store
|
||||
.read_with(cx, |store, _| store.context().next().cloned())
|
||||
.unwrap();
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(vec![context], &project, &None, cx))
|
||||
.await;
|
||||
|
||||
// Insert user message and assistant response
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
|
||||
thread.insert_assistant_message(
|
||||
vec![MessageSegment::Text("This code prints 42.".into())],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// We shouldn't have a stale buffer notification yet
|
||||
let notification = thread.read_with(cx, |thread, _| {
|
||||
find_tool_use(thread, "project_notifications")
|
||||
});
|
||||
assert!(
|
||||
notification.is_none(),
|
||||
"Should not have stale buffer notification before buffer is modified"
|
||||
);
|
||||
|
||||
// Modify the buffer
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(
|
||||
[(1..1, "\n println!(\"Added a new line\");\n")],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Insert another user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"What does the code do now?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
Vec::new(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Check for the stale buffer warning
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||
});
|
||||
|
||||
let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
|
||||
find_tool_use(thread, "project_notifications")
|
||||
}) else {
|
||||
panic!("Should have a `project_notifications` tool use");
|
||||
};
|
||||
|
||||
let Some(notification_content) = notification_result.content.to_str() else {
|
||||
panic!("`project_notifications` should return text");
|
||||
};
|
||||
|
||||
let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
|
||||
|
||||
These files have changed since the last read:
|
||||
- code.rs
|
||||
"};
|
||||
assert_eq!(notification_content, expected_content);
|
||||
}
|
||||
|
||||
fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
|
||||
thread
|
||||
.messages()
|
||||
.filter_map(|message| {
|
||||
thread
|
||||
.tool_results_for_message(message.id)
|
||||
.into_iter()
|
||||
.find(|result| result.tool_name == tool_name.into())
|
||||
})
|
||||
.next()
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
@ -5052,6 +5260,14 @@ fn main() {{
|
|||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
let http_client = Arc::new(http_client::HttpClientWithUrl::new(
|
||||
http_client::FakeHttpClient::with_200_response(),
|
||||
"http://localhost".to_string(),
|
||||
None,
|
||||
));
|
||||
assistant_tools::init(http_client, cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ mod list_directory_tool;
|
|||
mod move_path_tool;
|
||||
mod now_tool;
|
||||
mod open_tool;
|
||||
mod project_notifications_tool;
|
||||
mod read_file_tool;
|
||||
mod schema;
|
||||
mod templates;
|
||||
|
@ -45,6 +46,7 @@ pub use edit_file_tool::{EditFileMode, EditFileToolInput};
|
|||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use grep_tool::{GrepTool, GrepToolInput};
|
||||
pub use open_tool::OpenTool;
|
||||
pub use project_notifications_tool::ProjectNotificationsTool;
|
||||
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
|
||||
pub use terminal_tool::TerminalTool;
|
||||
|
||||
|
@ -61,6 +63,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
|||
registry.register_tool(ListDirectoryTool);
|
||||
registry.register_tool(NowTool);
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(ProjectNotificationsTool);
|
||||
registry.register_tool(FindPathTool);
|
||||
registry.register_tool(ReadFileTool);
|
||||
registry.register_tool(GrepTool);
|
||||
|
|
193
crates/assistant_tools/src/project_notifications_tool.rs
Normal file
193
crates/assistant_tools/src/project_notifications_tool.rs
Normal file
|
@ -0,0 +1,193 @@
|
|||
use crate::schema::json_schema_for;
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write as _;
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ProjectUpdatesToolInput {}
|
||||
|
||||
pub struct ProjectNotificationsTool;
|
||||
|
||||
impl Tool for ProjectNotificationsTool {
|
||||
fn name(&self) -> String {
|
||||
"project_notifications".to_string()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn description(&self) -> String {
|
||||
include_str!("./project_notifications_tool/description.md").to_string()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Envelope
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<ProjectUpdatesToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Check project notifications".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let mut stale_files = String::new();
|
||||
|
||||
let action_log = action_log.read(cx);
|
||||
|
||||
for stale_file in action_log.stale_buffers(cx) {
|
||||
if let Some(file) = stale_file.read(cx).file() {
|
||||
writeln!(&mut stale_files, "- {}", file.path().display()).ok();
|
||||
}
|
||||
}
|
||||
|
||||
let response = if stale_files.is_empty() {
|
||||
"No new notifications".to_string()
|
||||
} else {
|
||||
// NOTE: Changes to this prompt require a symmetric update in the LLM Worker
|
||||
const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
|
||||
format!("{HEADER}{stale_files}").replace("\r\n", "\n")
|
||||
};
|
||||
|
||||
Task::ready(Ok(response.into())).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assistant_tool::ToolResultContent;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/test"),
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
let buffer_path = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.find_project_path("test/code.rs", cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(buffer_path.clone(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Start tracking the buffer
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
});
|
||||
|
||||
// Run the tool before any changes
|
||||
let tool = Arc::new(ProjectNotificationsTool);
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
|
||||
let request = Arc::new(LanguageModelRequest::default());
|
||||
let tool_input = json!({});
|
||||
|
||||
let result = cx.update(|cx| {
|
||||
tool.clone().run(
|
||||
tool_input.clone(),
|
||||
request.clone(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let response = result.output.await.unwrap();
|
||||
let response_text = match &response.content {
|
||||
ToolResultContent::Text(text) => text.clone(),
|
||||
_ => panic!("Expected text response"),
|
||||
};
|
||||
assert_eq!(
|
||||
response_text.as_str(),
|
||||
"No new notifications",
|
||||
"Tool should return 'No new notifications' when no stale buffers"
|
||||
);
|
||||
|
||||
// Modify the buffer (makes it stale)
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit([(1..1, "\nChange!\n")], None, cx);
|
||||
});
|
||||
|
||||
// Run the tool again
|
||||
let result = cx.update(|cx| {
|
||||
tool.run(
|
||||
tool_input.clone(),
|
||||
request.clone(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// This time the buffer is stale, so the tool should return a notification
|
||||
let response = result.output.await.unwrap();
|
||||
let response_text = match &response.content {
|
||||
ToolResultContent::Text(text) => text.clone(),
|
||||
_ => panic!("Expected text response"),
|
||||
};
|
||||
|
||||
let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
|
||||
assert_eq!(
|
||||
response_text.as_str(),
|
||||
expected_content,
|
||||
"Tool should return the stale buffer notification"
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
assistant_tool::init(cx);
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
This tool reports which files have been modified by the user since the agent last accessed them.
|
||||
|
||||
It serves as a notification mechanism to inform the agent of recent changes. No immediate action is required in response to these updates.
|
|
@ -0,0 +1,3 @@
|
|||
[The following is an auto-generated notification; do not reply]
|
||||
|
||||
These files have changed since the last read:
|
|
@ -14,7 +14,7 @@ impl Example for FileChangeNotificationExample {
|
|||
url: "https://github.com/octocat/hello-world".to_string(),
|
||||
revision: "7fd1a60b01f91b314f59955a4e4d4e80d8edf11d".to_string(),
|
||||
language_server: None,
|
||||
max_assertions: Some(1),
|
||||
max_assertions: None,
|
||||
profile_id: AgentProfileId::default(),
|
||||
existing_thread_json: None,
|
||||
max_turns: Some(3),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue