agent: Send stale file notifications using the project_notifications tool (#34005)
This commit introduces the `project_notifications` tool, which proactively pushes notifications to the agent. Unlike other tools, `Thread` automatically invokes this tool on every turn, even when the LLM doesn't ask for it. When notifications are available, the tool use and results are inserted into the thread, simulating an LLM tool call. As with other tools, users can disable `project_notifications` in Profiles if they do not want them. Currently, the tool only notifies users about stale files: that is, files that have been edited by the user while the agent is also working on them. In the future, notifications may be expanded to include compiler diagnostics, long-running processes, and more. Release Notes: - Added `project_notifications` tool
This commit is contained in:
parent
de9053c7ca
commit
d87603dd60
8 changed files with 427 additions and 4 deletions
3
crates/agent/src/prompts/stale_files_prompt_header.txt
Normal file
3
crates/agent/src/prompts/stale_files_prompt_header.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
[The following is an auto-generated notification; do not reply]
|
||||
|
||||
These files have changed since the last read:
|
|
@ -25,8 +25,8 @@ use language_model::{
|
|||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
|
||||
Role, SelectedModel, StopReason, TokenUsage,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError,
|
||||
PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage,
|
||||
};
|
||||
use postage::stream::Stream as _;
|
||||
use project::{
|
||||
|
@ -45,7 +45,7 @@ use std::{
|
|||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use util::{ResultExt as _, debug_panic, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
|
@ -1248,6 +1248,8 @@ impl Thread {
|
|||
|
||||
self.remaining_turns -= 1;
|
||||
|
||||
self.flush_notifications(model.clone(), intent, cx);
|
||||
|
||||
let request = self.to_completion_request(model.clone(), intent, cx);
|
||||
|
||||
self.stream_completion(request, model, intent, window, cx);
|
||||
|
@ -1481,6 +1483,110 @@ impl Thread {
|
|||
request
|
||||
}
|
||||
|
||||
/// Insert auto-generated notifications (if any) to the thread
|
||||
fn flush_notifications(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match intent {
|
||||
CompletionIntent::UserPrompt | CompletionIntent::ToolResults => {
|
||||
if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) {
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
tool_use_id: pending_tool_use.id.clone(),
|
||||
pending_tool_use: Some(pending_tool_use),
|
||||
});
|
||||
}
|
||||
}
|
||||
CompletionIntent::ThreadSummarization
|
||||
| CompletionIntent::ThreadContextSummarization
|
||||
| CompletionIntent::CreateFile
|
||||
| CompletionIntent::EditFile
|
||||
| CompletionIntent::InlineAssist
|
||||
| CompletionIntent::TerminalInlineAssist
|
||||
| CompletionIntent::GenerateGitCommitMessage => {}
|
||||
};
|
||||
}
|
||||
|
||||
fn attach_tracked_files_state(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Option<PendingToolUse> {
|
||||
let action_log = self.action_log.read(cx);
|
||||
|
||||
action_log.stale_buffers(cx).next()?;
|
||||
|
||||
// Represent notification as a simulated `project_notifications` tool call
|
||||
let tool_name = Arc::from("project_notifications");
|
||||
let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else {
|
||||
debug_panic!("`project_notifications` tool not found");
|
||||
return None;
|
||||
};
|
||||
|
||||
if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let input = serde_json::json!({});
|
||||
let request = Arc::new(LanguageModelRequest::default()); // unused
|
||||
let window = None;
|
||||
let tool_result = tool.run(
|
||||
input,
|
||||
request,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
model.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let tool_use_id =
|
||||
LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len()));
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: tool_use_id.clone(),
|
||||
name: tool_name.clone(),
|
||||
raw_input: "{}".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
is_input_complete: true,
|
||||
};
|
||||
|
||||
let tool_output = cx.background_executor().block(tool_result.output);
|
||||
|
||||
// Attach a project_notification tool call to the latest existing
|
||||
// Assistant message. We cannot create a new Assistant message
|
||||
// because thinking models require a `thinking` block that we
|
||||
// cannot mock. We cannot send a notification as a normal
|
||||
// (non-tool-use) User message because this distracts Agent
|
||||
// too much.
|
||||
let tool_message_id = self
|
||||
.messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rfind(|(_, message)| message.role == Role::Assistant)
|
||||
.map(|(_, message)| message.id)?;
|
||||
|
||||
let tool_use_metadata = ToolUseMetadata {
|
||||
model: model.clone(),
|
||||
thread_id: self.id.clone(),
|
||||
prompt_id: self.last_prompt_id.clone(),
|
||||
};
|
||||
|
||||
self.tool_use
|
||||
.request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx);
|
||||
|
||||
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
tool_output,
|
||||
self.configured_model.as_ref(),
|
||||
);
|
||||
|
||||
pending_tool_use
|
||||
}
|
||||
|
||||
pub fn stream_completion(
|
||||
&mut self,
|
||||
request: LanguageModelRequest,
|
||||
|
@ -3156,10 +3262,13 @@ mod tests {
|
|||
const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30;
|
||||
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use assistant_tools;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use gpui::TestAppContext;
|
||||
use http_client;
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use language_model::{
|
||||
LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId,
|
||||
|
@ -3487,6 +3596,105 @@ fn main() {{
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, _thread_store, thread, context_store, model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Add a buffer to the context. This will be a tracked buffer
|
||||
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context = context_store
|
||||
.read_with(cx, |store, _| store.context().next().cloned())
|
||||
.unwrap();
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(vec![context], &project, &None, cx))
|
||||
.await;
|
||||
|
||||
// Insert user message and assistant response
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx);
|
||||
thread.insert_assistant_message(
|
||||
vec![MessageSegment::Text("This code prints 42.".into())],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// We shouldn't have a stale buffer notification yet
|
||||
let notification = thread.read_with(cx, |thread, _| {
|
||||
find_tool_use(thread, "project_notifications")
|
||||
});
|
||||
assert!(
|
||||
notification.is_none(),
|
||||
"Should not have stale buffer notification before buffer is modified"
|
||||
);
|
||||
|
||||
// Modify the buffer
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(
|
||||
[(1..1, "\n println!(\"Added a new line\");\n")],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Insert another user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"What does the code do now?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
Vec::new(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Check for the stale buffer warning
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||
});
|
||||
|
||||
let Some(notification_result) = thread.read_with(cx, |thread, _cx| {
|
||||
find_tool_use(thread, "project_notifications")
|
||||
}) else {
|
||||
panic!("Should have a `project_notifications` tool use");
|
||||
};
|
||||
|
||||
let Some(notification_content) = notification_result.content.to_str() else {
|
||||
panic!("`project_notifications` should return text");
|
||||
};
|
||||
|
||||
let expected_content = indoc! {"[The following is an auto-generated notification; do not reply]
|
||||
|
||||
These files have changed since the last read:
|
||||
- code.rs
|
||||
"};
|
||||
assert_eq!(notification_content, expected_content);
|
||||
}
|
||||
|
||||
fn find_tool_use(thread: &Thread, tool_name: &str) -> Option<LanguageModelToolResult> {
|
||||
thread
|
||||
.messages()
|
||||
.filter_map(|message| {
|
||||
thread
|
||||
.tool_results_for_message(message.id)
|
||||
.into_iter()
|
||||
.find(|result| result.tool_name == tool_name.into())
|
||||
})
|
||||
.next()
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
@ -5052,6 +5260,14 @@ fn main() {{
|
|||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
let http_client = Arc::new(http_client::HttpClientWithUrl::new(
|
||||
http_client::FakeHttpClient::with_200_response(),
|
||||
"http://localhost".to_string(),
|
||||
None,
|
||||
));
|
||||
assistant_tools::init(http_client, cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue