diff --git a/assets/settings/default.json b/assets/settings/default.json index 985e322cac..48cdd665e1 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -810,6 +810,7 @@ "edit_file": true, "fetch": true, "list_directory": true, + "project_notifications": true, "move_path": true, "now": true, "find_path": true, @@ -829,6 +830,7 @@ "diagnostics": true, "fetch": true, "list_directory": true, + "project_notifications": true, "now": true, "find_path": true, "read_file": true, diff --git a/crates/agent/src/prompts/stale_files_prompt_header.txt b/crates/agent/src/prompts/stale_files_prompt_header.txt new file mode 100644 index 0000000000..f743e239c8 --- /dev/null +++ b/crates/agent/src/prompts/stale_files_prompt_header.txt @@ -0,0 +1,3 @@ +[The following is an auto-generated notification; do not reply] + +These files have changed since the last read: diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 815b9e86ea..50d2a4d773 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -25,8 +25,8 @@ use language_model::{ ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, - LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError, - Role, SelectedModel, StopReason, TokenUsage, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, + PaymentRequiredError, Role, SelectedModel, StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::{ @@ -45,7 +45,7 @@ use std::{ time::{Duration, Instant}, }; use thiserror::Error; -use util::{ResultExt as _, post_inc}; +use util::{ResultExt as _, debug_panic, post_inc}; use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; @@ -1248,6 +1248,8 @@ impl Thread { self.remaining_turns -= 1; + self.flush_notifications(model.clone(), intent, cx); + let request = self.to_completion_request(model.clone(), intent, cx); self.stream_completion(request, model, intent, window, cx); @@ -1481,6 +1483,110 @@ impl Thread { request } + /// Insert auto-generated notifications (if any) to the thread + fn flush_notifications( + &mut self, + model: Arc, + intent: CompletionIntent, + cx: &mut Context, + ) { + match intent { + CompletionIntent::UserPrompt | CompletionIntent::ToolResults => { + if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) { + cx.emit(ThreadEvent::ToolFinished { + tool_use_id: pending_tool_use.id.clone(), + pending_tool_use: Some(pending_tool_use), + }); + } + } + CompletionIntent::ThreadSummarization + | CompletionIntent::ThreadContextSummarization + | CompletionIntent::CreateFile + | CompletionIntent::EditFile + | CompletionIntent::InlineAssist + | CompletionIntent::TerminalInlineAssist + | CompletionIntent::GenerateGitCommitMessage => {} + }; + } + + fn attach_tracked_files_state( + &mut self, + model: Arc, + cx: &mut App, + ) -> Option { + let action_log = self.action_log.read(cx); + + action_log.stale_buffers(cx).next()?; + + // Represent notification as a simulated `project_notifications` tool call + let tool_name = Arc::from("project_notifications"); + let Some(tool) = self.tools.read(cx).tool(&tool_name, cx) else { + debug_panic!("`project_notifications` tool not found"); + return None; + }; + + if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { + return None; + } + + let input = serde_json::json!({}); + let request = Arc::new(LanguageModelRequest::default()); // unused + let window = None; + let tool_result = tool.run( + input, + request, + self.project.clone(), + self.action_log.clone(), + model.clone(), + window, + cx, + ); + + let tool_use_id = + LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len())); + + let tool_use = LanguageModelToolUse { + id: tool_use_id.clone(), + name: tool_name.clone(), + raw_input: "{}".to_string(), + input: serde_json::json!({}), + is_input_complete: true, + }; + + let tool_output = cx.background_executor().block(tool_result.output); + + // Attach a project_notification tool call to the latest existing + // Assistant message. We cannot create a new Assistant message + // because thinking models require a `thinking` block that we + // cannot mock. We cannot send a notification as a normal + // (non-tool-use) User message because this distracts Agent + // too much. + let tool_message_id = self + .messages + .iter() + .enumerate() + .rfind(|(_, message)| message.role == Role::Assistant) + .map(|(_, message)| message.id)?; + + let tool_use_metadata = ToolUseMetadata { + model: model.clone(), + thread_id: self.id.clone(), + prompt_id: self.last_prompt_id.clone(), + }; + + self.tool_use + .request_tool_use(tool_message_id, tool_use, tool_use_metadata.clone(), cx); + + let pending_tool_use = self.tool_use.insert_tool_output( + tool_use_id.clone(), + tool_name, + tool_output, + self.configured_model.as_ref(), + ); + + pending_tool_use + } + pub fn stream_completion( &mut self, request: LanguageModelRequest, @@ -3156,10 +3262,13 @@ mod tests { const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; + use assistant_tools; use futures::StreamExt; use futures::future::BoxFuture; use futures::stream::BoxStream; use gpui::TestAppContext; + use http_client; + use indoc::indoc; use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use language_model::{ LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, @@ -3487,6 +3596,105 @@ fn main() {{ ); } + #[gpui::test] + async fn test_stale_buffer_notification(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, _thread_store, thread, context_store, model) = + setup_test_environment(cx, project.clone()).await; + + // Add a buffer to the context. This will be a tracked buffer + let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx) + .await + .unwrap(); + + let context = context_store + .read_with(cx, |store, _| store.context().next().cloned()) + .unwrap(); + let loaded_context = cx + .update(|cx| load_context(vec![context], &project, &None, cx)) + .await; + + // Insert user message and assistant response + thread.update(cx, |thread, cx| { + thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx); + thread.insert_assistant_message( + vec![MessageSegment::Text("This code prints 42.".into())], + cx, + ); + }); + + // We shouldn't have a stale buffer notification yet + let notification = thread.read_with(cx, |thread, _| { + find_tool_use(thread, "project_notifications") + }); + assert!( + notification.is_none(), + "Should not have stale buffer notification before buffer is modified" + ); + + // Modify the buffer + buffer.update(cx, |buffer, cx| { + buffer.edit( + [(1..1, "\n println!(\"Added a new line\");\n")], + None, + cx, + ); + }); + + // Insert another user message + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "What does the code do now?", + ContextLoadResult::default(), + None, + Vec::new(), + cx, + ) + }); + + // Check for the stale buffer warning + thread.update(cx, |thread, cx| { + thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) + }); + + let Some(notification_result) = thread.read_with(cx, |thread, _cx| { + find_tool_use(thread, "project_notifications") + }) else { + panic!("Should have a `project_notifications` tool use"); + }; + + let Some(notification_content) = notification_result.content.to_str() else { + panic!("`project_notifications` should return text"); + }; + + let expected_content = indoc! {"[The following is an auto-generated notification; do not reply] + + These files have changed since the last read: + - code.rs + "}; + assert_eq!(notification_content, expected_content); + } + + fn find_tool_use(thread: &Thread, tool_name: &str) -> Option { + thread + .messages() + .filter_map(|message| { + thread + .tool_results_for_message(message.id) + .into_iter() + .find(|result| result.tool_name == tool_name.into()) + }) + .next() + .cloned() + } + #[gpui::test] async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { init_test_settings(cx); @@ -5052,6 +5260,14 @@ fn main() {{ language_model::init_settings(cx); ThemeSettings::register(cx); ToolRegistry::default_global(cx); + assistant_tool::init(cx); + + let http_client = Arc::new(http_client::HttpClientWithUrl::new( + http_client::FakeHttpClient::with_200_response(), + "http://localhost".to_string(), + None, + )); + assistant_tools::init(http_client, cx); }); } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 83312a07b6..eef792f526 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -11,6 +11,7 @@ mod list_directory_tool; mod move_path_tool; mod now_tool; mod open_tool; +mod project_notifications_tool; mod read_file_tool; mod schema; mod templates; @@ -45,6 +46,7 @@ pub use edit_file_tool::{EditFileMode, EditFileToolInput}; pub use find_path_tool::FindPathToolInput; pub use grep_tool::{GrepTool, GrepToolInput}; pub use open_tool::OpenTool; +pub use project_notifications_tool::ProjectNotificationsTool; pub use read_file_tool::{ReadFileTool, ReadFileToolInput}; pub use terminal_tool::TerminalTool; @@ -61,6 +63,7 @@ pub fn init(http_client: Arc, cx: &mut App) { registry.register_tool(ListDirectoryTool); registry.register_tool(NowTool); registry.register_tool(OpenTool); + registry.register_tool(ProjectNotificationsTool); registry.register_tool(FindPathTool); registry.register_tool(ReadFileTool); registry.register_tool(GrepTool); diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs new file mode 100644 index 0000000000..552ebb3d53 --- /dev/null +++ b/crates/assistant_tools/src/project_notifications_tool.rs @@ -0,0 +1,193 @@ +use crate::schema::json_schema_for; +use anyhow::Result; +use assistant_tool::{ActionLog, Tool, ToolResult}; +use gpui::{AnyWindowHandle, App, Entity, Task}; +use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Write as _; +use std::sync::Arc; +use ui::IconName; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ProjectUpdatesToolInput {} + +pub struct ProjectNotificationsTool; + +impl Tool for ProjectNotificationsTool { + fn name(&self) -> String { + "project_notifications".to_string() + } + + fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + false + } + fn may_perform_edits(&self) -> bool { + false + } + fn description(&self) -> String { + include_str!("./project_notifications_tool/description.md").to_string() + } + + fn icon(&self) -> IconName { + IconName::Envelope + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + "Check project notifications".into() + } + + fn run( + self: Arc, + _input: serde_json::Value, + _request: Arc, + _project: Entity, + action_log: Entity, + _model: Arc, + _window: Option, + cx: &mut App, + ) -> ToolResult { + let mut stale_files = String::new(); + + let action_log = action_log.read(cx); + + for stale_file in action_log.stale_buffers(cx) { + if let Some(file) = stale_file.read(cx).file() { + writeln!(&mut stale_files, "- {}", file.path().display()).ok(); + } + } + + let response = if stale_files.is_empty() { + "No new notifications".to_string() + } else { + // NOTE: Changes to this prompt require a symmetric update in the LLM Worker + const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); + format!("{HEADER}{stale_files}").replace("\r\n", "\n") + }; + + Task::ready(Ok(response.into())).into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use assistant_tool::ToolResultContent; + use gpui::{AppContext, TestAppContext}; + use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use std::sync::Arc; + use util::path; + + #[gpui::test] + async fn test_stale_buffer_notification(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/test"), + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let buffer_path = project + .read_with(cx, |project, cx| { + project.find_project_path("test/code.rs", cx) + }) + .unwrap(); + + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(buffer_path.clone(), cx) + }) + .await + .unwrap(); + + // Start tracking the buffer + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + }); + + // Run the tool before any changes + let tool = Arc::new(ProjectNotificationsTool); + let provider = Arc::new(FakeLanguageModelProvider); + let model: Arc = Arc::new(provider.test_model()); + let request = Arc::new(LanguageModelRequest::default()); + let tool_input = json!({}); + + let result = cx.update(|cx| { + tool.clone().run( + tool_input.clone(), + request.clone(), + project.clone(), + action_log.clone(), + model.clone(), + None, + cx, + ) + }); + + let response = result.output.await.unwrap(); + let response_text = match &response.content { + ToolResultContent::Text(text) => text.clone(), + _ => panic!("Expected text response"), + }; + assert_eq!( + response_text.as_str(), + "No new notifications", + "Tool should return 'No new notifications' when no stale buffers" + ); + + // Modify the buffer (makes it stale) + buffer.update(cx, |buffer, cx| { + buffer.edit([(1..1, "\nChange!\n")], None, cx); + }); + + // Run the tool again + let result = cx.update(|cx| { + tool.run( + tool_input.clone(), + request.clone(), + project.clone(), + action_log, + model.clone(), + None, + cx, + ) + }); + + // This time the buffer is stale, so the tool should return a notification + let response = result.output.await.unwrap(); + let response_text = match &response.content { + ToolResultContent::Text(text) => text.clone(), + _ => panic!("Expected text response"), + }; + + let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n"; + assert_eq!( + response_text.as_str(), + expected_content, + "Tool should return the stale buffer notification" + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + assistant_tool::init(cx); + }); + } +} diff --git a/crates/assistant_tools/src/project_notifications_tool/description.md b/crates/assistant_tools/src/project_notifications_tool/description.md new file mode 100644 index 0000000000..24ff678f5e --- /dev/null +++ b/crates/assistant_tools/src/project_notifications_tool/description.md @@ -0,0 +1,3 @@ +This tool reports which files have been modified by the user since the agent last accessed them. + +It serves as a notification mechanism to inform the agent of recent changes. No immediate action is required in response to these updates. diff --git a/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt b/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt new file mode 100644 index 0000000000..f743e239c8 --- /dev/null +++ b/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt @@ -0,0 +1,3 @@ +[The following is an auto-generated notification; do not reply] + +These files have changed since the last read: diff --git a/crates/eval/src/examples/file_change_notification.rs b/crates/eval/src/examples/file_change_notification.rs index 0e4f770a67..7879ad6f2e 100644 --- a/crates/eval/src/examples/file_change_notification.rs +++ b/crates/eval/src/examples/file_change_notification.rs @@ -14,7 +14,7 @@ impl Example for FileChangeNotificationExample { url: "https://github.com/octocat/hello-world".to_string(), revision: "7fd1a60b01f91b314f59955a4e4d4e80d8edf11d".to_string(), language_server: None, - max_assertions: Some(1), + max_assertions: None, profile_id: AgentProfileId::default(), existing_thread_json: None, max_turns: Some(3),