From 99df1190a9e72b32e840de991dbbbe1c59d1f5d4 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Mon, 28 Apr 2025 22:47:40 +0300 Subject: [PATCH] agent: Include grep-related instructions in the prompt only if the tool is available (#29536) This change updates the system prompt to conditionally include `grep`-related instructions based on whether the `grep` tool is enabled. Implementation details: 1. Add a `has_tool` handlebars helper. 2. Pass the `model` to all locations where the prompt is built. 3. Use `{{#if (has_tool "grep")}}` in the system prompt to gate `grep`-specific instructions. Testing: - Unit tests for the `has_tool` helper. - Unit tests to verify that `grep`-related instructions are included / omitted from the prompt as appropriate. - Manual agent evaluation: - Setup: Asked the Agent "List all impls of MyTrait in the project" using a custom "No tools" profile (all tools disabled). - Before the change: The Agent attempted to call `grep`, encountered an error, then realized the tool was unavailable. - After the change: The Agent immediately asked to enable a search tool. Note: in principle, `grep`/`read_file` tool descriptions alone might be enough, but to confirm this we need more evaluation. If it turns out to be true, we'll be able to remove grep-specific instructions from the system prompt and undo this change. 
Release Notes: - N/A --- Cargo.lock | 1 + assets/prompts/assistant_system_prompt.hbs | 7 +- crates/agent/src/thread.rs | 124 ++++++++++++++------- crates/prompt_store/Cargo.toml | 1 + crates/prompt_store/src/prompts.rs | 122 +++++++++++++++++++- 5 files changed, 209 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c822baf289..c1a30686df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11114,6 +11114,7 @@ dependencies = [ "paths", "rope", "serde", + "serde_json", "text", "util", "uuid", diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index 4a3e574e67..a1cabaab69 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -27,13 +27,14 @@ If appropriate, use tool calls to explore the current project, which contains th - `{{root_name}}` {{/each}} +- Bias towards not asking the user for help if you can find the answer yourself. - When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above. +- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path! +{{# if (has_tool 'grep') }} - When looking for symbols in the project, prefer the `grep` tool. - As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. -- Bias towards not asking the user for help if you can find the answer yourself. -{{! TODO: Only mention tools if they are enabled }} - The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file. -- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path! 
+{{/if}} ## Fixing Diagnostics diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f439f1f8fc..4e62dbea29 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -26,7 +26,7 @@ use language_model::{ }; use project::Project; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; -use prompt_store::PromptBuilder; +use prompt_store::{ModelContext, PromptBuilder}; use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -740,6 +740,32 @@ impl Thread { self.tool_use.tool_result_card(id).cloned() } + /// Return tools that are both enabled and supported by the model + pub fn available_tools( + &self, + cx: &App, + model: Arc, + ) -> Vec { + if model.supports_tools() { + self.tools() + .read(cx) + .enabled_tools(cx) + .into_iter() + .filter_map(|tool| { + // Skip tools that cannot be supported + let input_schema = tool.input_schema(model.tool_input_format()).ok()?; + Some(LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema, + }) + }) + .collect() + } else { + Vec::default() + } + } + pub fn insert_user_message( &mut self, text: impl Into, @@ -941,30 +967,7 @@ impl Thread { self.remaining_turns -= 1; - let mut request = self.to_completion_request(cx); - request.mode = if model.supports_max_mode() { - self.completion_mode - } else { - None - }; - - if model.supports_tools() { - request.tools = self - .tools() - .read(cx) - .enabled_tools(cx) - .into_iter() - .filter_map(|tool| { - // Skip tools that cannot be supported - let input_schema = tool.input_schema(model.tool_input_format()).ok()?; - Some(LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema, - }) - }) - .collect(); - } + let request = self.to_completion_request(model.clone(), cx); self.stream_completion(request, model, window, cx); } @@ -981,7 +984,11 @@ impl Thread { false } - pub fn to_completion_request(&self, cx: &mut Context) -> 
LanguageModelRequest { + pub fn to_completion_request( + &self, + model: Arc, + cx: &mut Context, + ) -> LanguageModelRequest { let mut request = LanguageModelRequest { thread_id: Some(self.id.to_string()), prompt_id: Some(self.last_prompt_id.to_string()), @@ -992,10 +999,20 @@ impl Thread { temperature: None, }; + let available_tools = self.available_tools(cx, model.clone()); + let available_tool_names = available_tools + .iter() + .map(|tool| tool.name.clone()) + .collect(); + + let model_context = &ModelContext { + available_tools: available_tool_names, + }; + if let Some(project_context) = self.project_context.borrow().as_ref() { match self .prompt_builder - .generate_assistant_system_prompt(project_context) + .generate_assistant_system_prompt(project_context, model_context) { Err(err) => { let message = format!("{err:?}").into(); @@ -1075,6 +1092,13 @@ impl Thread { self.attached_tracked_files_state(&mut request.messages, cx); + request.tools = available_tools; + request.mode = if model.supports_max_mode() { + self.completion_mode + } else { + None + }; + request } @@ -1376,7 +1400,7 @@ impl Thread { match result.as_ref() { Ok(stop_reason) => match stop_reason { StopReason::ToolUse => { - let tool_uses = thread.use_pending_tools(window, cx); + let tool_uses = thread.use_pending_tools(window, cx, model.clone()); cx.emit(ThreadEvent::UsePendingTools { tool_uses }); } StopReason::EndTurn => {} @@ -1594,9 +1618,10 @@ impl Thread { &mut self, window: Option, cx: &mut Context, + model: Arc, ) -> Vec { self.auto_capture_telemetry(cx); - let request = self.to_completion_request(cx); + let request = self.to_completion_request(model, cx); let messages = Arc::new(request.messages); let pending_tool_uses = self .tool_use @@ -2316,9 +2341,11 @@ mod tests { use super::*; use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; use assistant_settings::AssistantSettings; + use assistant_tool::ToolRegistry; use 
context_server::ContextServerSettings; use editor::EditorSettings; use gpui::TestAppContext; + use language_model::fake_provider::FakeLanguageModel; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; @@ -2338,7 +2365,7 @@ mod tests { ) .await; - let (_workspace, _thread_store, thread, context_store) = + let (_workspace, _thread_store, thread, context_store, model) = setup_test_environment(cx, project.clone()).await; add_file_to_context(&project, &context_store, "test/code.rs", cx) @@ -2389,7 +2416,9 @@ fn main() {{ assert_eq!(message.loaded_context.text, expected_context); // Check message in request - let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); assert_eq!(request.messages.len(), 2); let expected_full_message = format!("{}Please explain this code", expected_context); @@ -2410,7 +2439,7 @@ fn main() {{ ) .await; - let (_, _thread_store, thread, context_store) = + let (_, _thread_store, thread, context_store, model) = setup_test_environment(cx, project.clone()).await; // First message with context 1 @@ -2481,7 +2510,9 @@ fn main() {{ assert!(message3.loaded_context.text.contains("file3.rs")); // Check entire request to make sure all contexts are properly included - let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); // The request should contain all 3 messages assert_eq!(request.messages.len(), 4); @@ -2510,7 +2541,7 @@ fn main() {{ ) .await; - let (_, _thread_store, thread, _context_store) = + let (_, _thread_store, thread, _context_store, model) = setup_test_environment(cx, project.clone()).await; // Insert user message without any context (empty context vector) @@ -2536,7 +2567,9 @@ fn main() {{ assert_eq!(message.loaded_context.text, ""); // Check message in 
request - let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); assert_eq!(request.messages.len(), 2); assert_eq!( @@ -2559,7 +2592,9 @@ fn main() {{ assert_eq!(message2.loaded_context.text, ""); // Check that both messages appear in the request - let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); assert_eq!(request.messages.len(), 3); assert_eq!( @@ -2582,7 +2617,7 @@ fn main() {{ ) .await; - let (_workspace, _thread_store, thread, context_store) = + let (_workspace, _thread_store, thread, context_store, model) = setup_test_environment(cx, project.clone()).await; // Open buffer and add it to context @@ -2601,7 +2636,9 @@ fn main() {{ }); // Create a request and check that it doesn't have a stale buffer warning yet - let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let initial_request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); // Make sure we don't have a stale file warning yet let has_stale_warning = initial_request.messages.iter().any(|msg| { @@ -2634,7 +2671,9 @@ fn main() {{ }); // Create a new request and check for the stale buffer warning - let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx)); + let new_request = thread.update(cx, |thread, cx| { + thread.to_completion_request(model.clone(), cx) + }); // We should have a stale file warning as the last message let last_message = new_request @@ -2667,6 +2706,7 @@ fn main() {{ ThemeSettings::register(cx); ContextServerSettings::register(cx); EditorSettings::register(cx); + ToolRegistry::default_global(cx); }); } @@ -2688,6 +2728,7 @@ fn main() {{ Entity, Entity, Entity, + Arc, ) { let (workspace, cx) = cx.add_window_view(|window, cx| 
Workspace::test_new(project.clone(), window, cx)); @@ -2708,7 +2749,10 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - (workspace, thread_store, thread, context_store) + let model = FakeLanguageModel::default(); + let model: Arc = Arc::new(model); + + (workspace, thread_store, thread, context_store, model) } async fn add_file_to_context( diff --git a/crates/prompt_store/Cargo.toml b/crates/prompt_store/Cargo.toml index 7d75f19ecf..d749378138 100644 --- a/crates/prompt_store/Cargo.toml +++ b/crates/prompt_store/Cargo.toml @@ -28,6 +28,7 @@ parking_lot.workspace = true paths.workspace = true rope.workspace = true serde.workspace = true +serde_json.workspace = true text.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 0a53957e36..afcab4758e 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -48,6 +48,20 @@ impl ProjectContext { } } +#[derive(Debug, Clone, Serialize)] +pub struct ModelContext { + pub available_tools: Vec, +} + +#[derive(Serialize)] +struct PromptTemplateContext { + #[serde(flatten)] + project: ProjectContext, + + #[serde(flatten)] + model: ModelContext, +} + #[derive(Debug, Clone, Serialize)] pub struct UserRulesContext { pub uuid: UserPromptId, @@ -124,9 +138,40 @@ impl PromptBuilder { .unwrap_or_else(|| Arc::new(Self::new(None).unwrap())) } + /// Helper function for handlebars templates to check if a specific tool is enabled + fn has_tool_helper( + h: &handlebars::Helper, + _: &Handlebars, + ctx: &handlebars::Context, + _: &mut handlebars::RenderContext, + out: &mut dyn handlebars::Output, + ) -> handlebars::HelperResult { + let tool_name = h.param(0).and_then(|v| v.value().as_str()).ok_or_else(|| { + handlebars::RenderError::new("has_tool helper: missing or invalid tool name 
parameter") + })?; + + let enabled_tools = ctx + .data() + .get("available_tools") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect::>()) + .ok_or_else(|| { + handlebars::RenderError::new( + "has_tool handlebars helper: available_tools not found or not an array", + ) + })?; + + if enabled_tools.contains(&tool_name) { + out.write("true")?; + } + + Ok(()) + } + pub fn new(loading_params: Option) -> Result { let mut handlebars = Handlebars::new(); Self::register_built_in_templates(&mut handlebars)?; + handlebars.register_helper("has_tool", Box::new(Self::has_tool_helper)); let handlebars = Arc::new(Mutex::new(handlebars)); @@ -278,10 +323,16 @@ impl PromptBuilder { pub fn generate_assistant_system_prompt( &self, context: &ProjectContext, + model_context: &ModelContext, ) -> Result { + let template_context = PromptTemplateContext { + project: context.clone(), + model: model_context.clone(), + }; + self.handlebars .lock() - .render("assistant_system_prompt", context) + .render("assistant_system_prompt", &template_context) } pub fn generate_inline_transformation_prompt( @@ -398,6 +449,7 @@ impl PromptBuilder { #[cfg(test)] mod test { use super::*; + use serde_json; use uuid::Uuid; #[test] @@ -416,9 +468,73 @@ mod test { contents: "Rules contents".into(), }]; let project_context = ProjectContext::new(worktrees, default_user_rules); - PromptBuilder::new(None) + let model_context = ModelContext { + available_tools: ["grep".into()].to_vec(), + }; + let prompt = PromptBuilder::new(None) .unwrap() - .generate_assistant_system_prompt(&project_context) + .generate_assistant_system_prompt(&project_context, &model_context) .unwrap(); + assert!( + prompt.contains("Rules contents"), + "Expected default user rules to be in rendered prompt" + ); + } + + #[test] + fn test_assistant_system_prompt_depends_on_enabled_tools() { + let worktrees = vec![WorktreeContext { + root_name: "path".into(), + rules_file: None, + }]; + let default_user_rules = 
vec![]; + let project_context = ProjectContext::new(worktrees, default_user_rules); + let prompt_builder = PromptBuilder::new(None).unwrap(); + + // When the `grep` tool is enabled, it should be mentioned in the prompt + let model_context = ModelContext { + available_tools: ["grep".into()].to_vec(), + }; + let prompt_with_grep = prompt_builder + .generate_assistant_system_prompt(&project_context, &model_context) + .unwrap(); + assert!( + prompt_with_grep.contains("grep"), + "`grep` tool should be mentioned in prompt when the tool is enabled" + ); + + // When the `grep` tool is disabled, it should not be mentioned in the prompt + let model_context = ModelContext { + available_tools: [].to_vec(), + }; + let prompt_without_grep = prompt_builder + .generate_assistant_system_prompt(&project_context, &model_context) + .unwrap(); + assert!( + !prompt_without_grep.contains("grep"), + "`grep` tool should not be mentioned in prompt when the tool is disabled" + ); + } + + #[test] + fn test_has_tool_helper() { + let mut handlebars = Handlebars::new(); + handlebars.register_helper("has_tool", Box::new(PromptBuilder::has_tool_helper)); + handlebars + .register_template_string( + "test_template", + "{{#if (has_tool 'grep')}}grep is enabled{{else}}grep is disabled{{/if}}", + ) + .unwrap(); + + // grep available + let data = serde_json::json!({"available_tools": ["grep", "fetch"]}); + let result = handlebars.render("test_template", &data).unwrap(); + assert_eq!(result, "grep is enabled"); + + // grep not available + let data = serde_json::json!({"available_tools": ["terminal", "fetch"]}); + let result = handlebars.render("test_template", &data).unwrap(); + assert_eq!(result, "grep is disabled"); } }