agent: Include grep-related instructions in the prompt only if the tool is available (#29536)

This change updates the system prompt to conditionally include
`grep`-related instructions based on whether the `grep` tool is enabled.

Implementation details:
1. Add a `has_tool` handlebars helper (a sketch follows this list).
2. Pass the `model` to all locations where the prompt is built.
3. Use `{{#if (has_tool "grep")}}` in the system prompt to gate
   `grep`-specific instructions.
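
For illustration, a minimal sketch of what such a helper could look like with the `handlebars` crate is below. The helper name and the `available_tools` field match the diff further down; the function body, the registration call, and the assumption that the template data exposes `available_tools` directly are hypothetical, not the actual implementation.

```rust
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext};

// Sketch: report whether the tool named by the first parameter is present
// in the `available_tools` array of the template data.
fn has_tool_helper(
    helper: &Helper,
    _: &Handlebars,
    ctx: &Context,
    _: &mut RenderContext,
    out: &mut dyn Output,
) -> HelperResult {
    let available = helper
        .param(0)
        .and_then(|p| p.value().as_str())
        .map(|name| {
            ctx.data()
                .get("available_tools")
                .and_then(|tools| tools.as_array())
                .map(|tools| tools.iter().any(|t| t.as_str() == Some(name)))
                .unwrap_or(false)
        })
        .unwrap_or(false);
    // When the helper is used as a subexpression, e.g. `(has_tool "grep")`,
    // handlebars captures this output as the value `#if` tests: an empty
    // string is falsy, a non-empty one is truthy.
    out.write(if available { "true" } else { "" })?;
    Ok(())
}

// Registration (hypothetical call site):
// handlebars.register_helper("has_tool", Box::new(has_tool_helper));
```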

Testing:
- Unit tests for the `has_tool` helper.
- Unit tests to verify that `grep`-related instructions are included in or
  omitted from the prompt as appropriate (a test sketch follows this list).
- Manual agent evaluation:
  - Setup: Asked the Agent "List all impls of MyTrait in the project"
    using a custom "No tools" profile (all tools disabled).
  - Before the change: The Agent attempted to call `grep`, encountered an
    error, then realized the tool was unavailable.
  - After the change: The Agent immediately asked to enable a search tool.
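
As a rough sketch of the inclusion/omission check (not the actual test code), one could render a tiny gated template against the `has_tool_helper` sketched above; the template string and the data shape are stand-ins for the real system prompt and `ModelContext`.

```rust
#[test]
fn grep_instructions_are_gated_on_tool_availability() {
    let mut handlebars = handlebars::Handlebars::new();
    handlebars.register_helper("has_tool", Box::new(has_tool_helper));

    // Stand-in for the real system prompt template.
    let template =
        r#"{{#if (has_tool "grep")}}Prefer `grep` for project-wide searches.{{/if}}"#;

    // With `grep` available, the gated instructions are rendered.
    let with_grep = handlebars
        .render_template(template, &serde_json::json!({ "available_tools": ["grep"] }))
        .unwrap();
    assert!(with_grep.contains("grep"));

    // Without `grep`, the block renders nothing.
    let without_grep = handlebars
        .render_template(template, &serde_json::json!({ "available_tools": [] }))
        .unwrap();
    assert!(without_grep.is_empty());
}
```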

Note: In principle, the `grep`/`read_file` tool descriptions alone might be
enough, but confirming that requires more evaluation. If it turns out to be
the case, we can remove the `grep`-specific instructions from the system
prompt and undo this change.

Release Notes:

- N/A
Oleksiy Syvokon 2025-04-28 22:47:40 +03:00 committed by GitHub
parent 0e477e7db9
commit 99df1190a9
5 changed files with 209 additions and 46 deletions

@@ -26,7 +26,7 @@ use language_model::{
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use prompt_store::{ModelContext, PromptBuilder};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -740,6 +740,32 @@ impl Thread {
self.tool_use.tool_result_card(id).cloned()
}
/// Return tools that are both enabled and supported by the model
pub fn available_tools(
&self,
cx: &App,
model: Arc<dyn LanguageModel>,
) -> Vec<LanguageModelRequestTool> {
if model.supports_tools() {
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
})
.collect()
} else {
Vec::default()
}
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -941,30 +967,7 @@ impl Thread {
self.remaining_turns -= 1;
let mut request = self.to_completion_request(cx);
request.mode = if model.supports_max_mode() {
self.completion_mode
} else {
None
};
if model.supports_tools() {
request.tools = self
.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
})
.collect();
}
let request = self.to_completion_request(model.clone(), cx);
self.stream_completion(request, model, window, cx);
}
@@ -981,7 +984,11 @@ impl Thread {
false
}
pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
pub fn to_completion_request(
&self,
model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.last_prompt_id.to_string()),
@@ -992,10 +999,20 @@ impl Thread {
temperature: None,
};
let available_tools = self.available_tools(cx, model.clone());
let available_tool_names = available_tools
.iter()
.map(|tool| tool.name.clone())
.collect();
let model_context = &ModelContext {
available_tools: available_tool_names,
};
if let Some(project_context) = self.project_context.borrow().as_ref() {
match self
.prompt_builder
.generate_assistant_system_prompt(project_context)
.generate_assistant_system_prompt(project_context, model_context)
{
Err(err) => {
let message = format!("{err:?}").into();
@@ -1075,6 +1092,13 @@ impl Thread {
self.attached_tracked_files_state(&mut request.messages, cx);
request.tools = available_tools;
request.mode = if model.supports_max_mode() {
self.completion_mode
} else {
None
};
request
}
@@ -1376,7 +1400,7 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(window, cx);
let tool_uses = thread.use_pending_tools(window, cx, model.clone());
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
@@ -1594,9 +1618,10 @@ impl Thread {
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = self.to_completion_request(cx);
let request = self.to_completion_request(model, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
.tool_use
@@ -2316,9 +2341,11 @@ mod tests {
use super::*;
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
use prompt_store::PromptBuilder;
use serde_json::json;
@@ -2338,7 +2365,7 @@ mod tests {
)
.await;
let (_workspace, _thread_store, thread, context_store) =
let (_workspace, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
add_file_to_context(&project, &context_store, "test/code.rs", cx)
@@ -2389,7 +2416,9 @@ fn main() {{
assert_eq!(message.loaded_context.text, expected_context);
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
@@ -2410,7 +2439,7 @@ fn main() {{
)
.await;
let (_, _thread_store, thread, context_store) =
let (_, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
// First message with context 1
@@ -2481,7 +2510,9 @@ fn main() {{
assert!(message3.loaded_context.text.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 4);
@@ -2510,7 +2541,7 @@ fn main() {{
)
.await;
let (_, _thread_store, thread, _context_store) =
let (_, _thread_store, thread, _context_store, model) =
setup_test_environment(cx, project.clone()).await;
// Insert user message without any context (empty context vector)
@@ -2536,7 +2567,9 @@ fn main() {{
assert_eq!(message.loaded_context.text, "");
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(
@@ -2559,7 +2592,9 @@ fn main() {{
assert_eq!(message2.loaded_context.text, "");
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 3);
assert_eq!(
@@ -2582,7 +2617,7 @@ fn main() {{
)
.await;
let (_workspace, _thread_store, thread, context_store) =
let (_workspace, _thread_store, thread, context_store, model) =
setup_test_environment(cx, project.clone()).await;
// Open buffer and add it to context
@@ -2601,7 +2636,9 @@ fn main() {{
});
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let initial_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
// Make sure we don't have a stale file warning yet
let has_stale_warning = initial_request.messages.iter().any(|msg| {
@@ -2634,7 +2671,9 @@ fn main() {{
});
// Create a new request and check for the stale buffer warning
let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
let new_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), cx)
});
// We should have a stale file warning as the last message
let last_message = new_request
@@ -2667,6 +2706,7 @@ fn main() {{
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
ToolRegistry::default_global(cx);
});
}
@@ -2688,6 +2728,7 @@ fn main() {{
Entity<ThreadStore>,
Entity<Thread>,
Entity<ContextStore>,
Arc<dyn LanguageModel>,
) {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -2708,7 +2749,10 @@ fn main() {{
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
(workspace, thread_store, thread, context_store)
let model = FakeLanguageModel::default();
let model: Arc<dyn LanguageModel> = Arc::new(model);
(workspace, thread_store, thread, context_store, model)
}
async fn add_file_to_context(