agent: Include grep-related instructions in the prompt only if the tool is available (#29536)
This change updates the system prompt to conditionally include `grep`-related instructions based on whether the `grep` tool is enabled.

Implementation details:

1. Add a `has_tool` handlebars helper.
2. Pass the `model` to all locations where the prompt is built.
3. Use `{{#if (has_tool 'grep')}}` in the system prompt to gate `grep`-specific instructions.

Testing:

- Unit tests for the `has_tool` helper.
- Unit tests to verify that `grep`-related instructions are included in / omitted from the prompt as appropriate.
- Manual agent evaluation:
  - Setup: Asked the Agent "List all impls of MyTrait in the project" using a custom "No tools" profile (all tools disabled).
  - Before the change: The Agent attempted to call `grep`, encountered an error, then realized the tool was unavailable.
  - After the change: The Agent immediately asked to enable a search tool.

Note: in principle, the `grep`/`read_file` tool descriptions alone might be enough, but confirming this requires more evaluation. If it turns out to be true, we'll be able to remove the grep-specific instructions from the system prompt and undo this change.

Release Notes:

- N/A
This commit is contained in:
parent
0e477e7db9
commit
99df1190a9
5 changed files with 209 additions and 46 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -11114,6 +11114,7 @@ dependencies = [
|
||||||
"paths",
|
"paths",
|
||||||
"rope",
|
"rope",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"text",
|
"text",
|
||||||
"util",
|
"util",
|
||||||
"uuid",
|
"uuid",
|
||||||
|
|
|
@ -27,13 +27,14 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||||
- `{{root_name}}`
|
- `{{root_name}}`
|
||||||
{{/each}}
|
{{/each}}
|
||||||
|
|
||||||
|
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||||
|
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||||
|
{{# if (has_tool 'grep') }}
|
||||||
- When looking for symbols in the project, prefer the `grep` tool.
|
- When looking for symbols in the project, prefer the `grep` tool.
|
||||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
|
||||||
{{! TODO: Only mention tools if they are enabled }}
|
|
||||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
{{/if}}
|
||||||
|
|
||||||
## Fixing Diagnostics
|
## Fixing Diagnostics
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ use language_model::{
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||||
use prompt_store::PromptBuilder;
|
use prompt_store::{ModelContext, PromptBuilder};
|
||||||
use proto::Plan;
|
use proto::Plan;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -740,6 +740,32 @@ impl Thread {
|
||||||
self.tool_use.tool_result_card(id).cloned()
|
self.tool_use.tool_result_card(id).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return tools that are both enabled and supported by the model
|
||||||
|
pub fn available_tools(
|
||||||
|
&self,
|
||||||
|
cx: &App,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
) -> Vec<LanguageModelRequestTool> {
|
||||||
|
if model.supports_tools() {
|
||||||
|
self.tools()
|
||||||
|
.read(cx)
|
||||||
|
.enabled_tools(cx)
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|tool| {
|
||||||
|
// Skip tools that cannot be supported
|
||||||
|
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||||
|
Some(LanguageModelRequestTool {
|
||||||
|
name: tool.name(),
|
||||||
|
description: tool.description(),
|
||||||
|
input_schema,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
Vec::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn insert_user_message(
|
pub fn insert_user_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
text: impl Into<String>,
|
text: impl Into<String>,
|
||||||
|
@ -941,30 +967,7 @@ impl Thread {
|
||||||
|
|
||||||
self.remaining_turns -= 1;
|
self.remaining_turns -= 1;
|
||||||
|
|
||||||
let mut request = self.to_completion_request(cx);
|
let request = self.to_completion_request(model.clone(), cx);
|
||||||
request.mode = if model.supports_max_mode() {
|
|
||||||
self.completion_mode
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if model.supports_tools() {
|
|
||||||
request.tools = self
|
|
||||||
.tools()
|
|
||||||
.read(cx)
|
|
||||||
.enabled_tools(cx)
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|tool| {
|
|
||||||
// Skip tools that cannot be supported
|
|
||||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
|
||||||
Some(LanguageModelRequestTool {
|
|
||||||
name: tool.name(),
|
|
||||||
description: tool.description(),
|
|
||||||
input_schema,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
}
|
|
||||||
|
|
||||||
self.stream_completion(request, model, window, cx);
|
self.stream_completion(request, model, window, cx);
|
||||||
}
|
}
|
||||||
|
@ -981,7 +984,11 @@ impl Thread {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_completion_request(&self, cx: &mut Context<Self>) -> LanguageModelRequest {
|
pub fn to_completion_request(
|
||||||
|
&self,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> LanguageModelRequest {
|
||||||
let mut request = LanguageModelRequest {
|
let mut request = LanguageModelRequest {
|
||||||
thread_id: Some(self.id.to_string()),
|
thread_id: Some(self.id.to_string()),
|
||||||
prompt_id: Some(self.last_prompt_id.to_string()),
|
prompt_id: Some(self.last_prompt_id.to_string()),
|
||||||
|
@ -992,10 +999,20 @@ impl Thread {
|
||||||
temperature: None,
|
temperature: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let available_tools = self.available_tools(cx, model.clone());
|
||||||
|
let available_tool_names = available_tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| tool.name.clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let model_context = &ModelContext {
|
||||||
|
available_tools: available_tool_names,
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(project_context) = self.project_context.borrow().as_ref() {
|
if let Some(project_context) = self.project_context.borrow().as_ref() {
|
||||||
match self
|
match self
|
||||||
.prompt_builder
|
.prompt_builder
|
||||||
.generate_assistant_system_prompt(project_context)
|
.generate_assistant_system_prompt(project_context, model_context)
|
||||||
{
|
{
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let message = format!("{err:?}").into();
|
let message = format!("{err:?}").into();
|
||||||
|
@ -1075,6 +1092,13 @@ impl Thread {
|
||||||
|
|
||||||
self.attached_tracked_files_state(&mut request.messages, cx);
|
self.attached_tracked_files_state(&mut request.messages, cx);
|
||||||
|
|
||||||
|
request.tools = available_tools;
|
||||||
|
request.mode = if model.supports_max_mode() {
|
||||||
|
self.completion_mode
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1376,7 +1400,7 @@ impl Thread {
|
||||||
match result.as_ref() {
|
match result.as_ref() {
|
||||||
Ok(stop_reason) => match stop_reason {
|
Ok(stop_reason) => match stop_reason {
|
||||||
StopReason::ToolUse => {
|
StopReason::ToolUse => {
|
||||||
let tool_uses = thread.use_pending_tools(window, cx);
|
let tool_uses = thread.use_pending_tools(window, cx, model.clone());
|
||||||
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
||||||
}
|
}
|
||||||
StopReason::EndTurn => {}
|
StopReason::EndTurn => {}
|
||||||
|
@ -1594,9 +1618,10 @@ impl Thread {
|
||||||
&mut self,
|
&mut self,
|
||||||
window: Option<AnyWindowHandle>,
|
window: Option<AnyWindowHandle>,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
) -> Vec<PendingToolUse> {
|
) -> Vec<PendingToolUse> {
|
||||||
self.auto_capture_telemetry(cx);
|
self.auto_capture_telemetry(cx);
|
||||||
let request = self.to_completion_request(cx);
|
let request = self.to_completion_request(model, cx);
|
||||||
let messages = Arc::new(request.messages);
|
let messages = Arc::new(request.messages);
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
.tool_use
|
.tool_use
|
||||||
|
@ -2316,9 +2341,11 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
|
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
|
||||||
use assistant_settings::AssistantSettings;
|
use assistant_settings::AssistantSettings;
|
||||||
|
use assistant_tool::ToolRegistry;
|
||||||
use context_server::ContextServerSettings;
|
use context_server::ContextServerSettings;
|
||||||
use editor::EditorSettings;
|
use editor::EditorSettings;
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
|
use language_model::fake_provider::FakeLanguageModel;
|
||||||
use project::{FakeFs, Project};
|
use project::{FakeFs, Project};
|
||||||
use prompt_store::PromptBuilder;
|
use prompt_store::PromptBuilder;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
@ -2338,7 +2365,7 @@ mod tests {
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (_workspace, _thread_store, thread, context_store) =
|
let (_workspace, _thread_store, thread, context_store, model) =
|
||||||
setup_test_environment(cx, project.clone()).await;
|
setup_test_environment(cx, project.clone()).await;
|
||||||
|
|
||||||
add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
add_file_to_context(&project, &context_store, "test/code.rs", cx)
|
||||||
|
@ -2389,7 +2416,9 @@ fn main() {{
|
||||||
assert_eq!(message.loaded_context.text, expected_context);
|
assert_eq!(message.loaded_context.text, expected_context);
|
||||||
|
|
||||||
// Check message in request
|
// Check message in request
|
||||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 2);
|
assert_eq!(request.messages.len(), 2);
|
||||||
let expected_full_message = format!("{}Please explain this code", expected_context);
|
let expected_full_message = format!("{}Please explain this code", expected_context);
|
||||||
|
@ -2410,7 +2439,7 @@ fn main() {{
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (_, _thread_store, thread, context_store) =
|
let (_, _thread_store, thread, context_store, model) =
|
||||||
setup_test_environment(cx, project.clone()).await;
|
setup_test_environment(cx, project.clone()).await;
|
||||||
|
|
||||||
// First message with context 1
|
// First message with context 1
|
||||||
|
@ -2481,7 +2510,9 @@ fn main() {{
|
||||||
assert!(message3.loaded_context.text.contains("file3.rs"));
|
assert!(message3.loaded_context.text.contains("file3.rs"));
|
||||||
|
|
||||||
// Check entire request to make sure all contexts are properly included
|
// Check entire request to make sure all contexts are properly included
|
||||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
// The request should contain all 3 messages
|
// The request should contain all 3 messages
|
||||||
assert_eq!(request.messages.len(), 4);
|
assert_eq!(request.messages.len(), 4);
|
||||||
|
@ -2510,7 +2541,7 @@ fn main() {{
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (_, _thread_store, thread, _context_store) =
|
let (_, _thread_store, thread, _context_store, model) =
|
||||||
setup_test_environment(cx, project.clone()).await;
|
setup_test_environment(cx, project.clone()).await;
|
||||||
|
|
||||||
// Insert user message without any context (empty context vector)
|
// Insert user message without any context (empty context vector)
|
||||||
|
@ -2536,7 +2567,9 @@ fn main() {{
|
||||||
assert_eq!(message.loaded_context.text, "");
|
assert_eq!(message.loaded_context.text, "");
|
||||||
|
|
||||||
// Check message in request
|
// Check message in request
|
||||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 2);
|
assert_eq!(request.messages.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -2559,7 +2592,9 @@ fn main() {{
|
||||||
assert_eq!(message2.loaded_context.text, "");
|
assert_eq!(message2.loaded_context.text, "");
|
||||||
|
|
||||||
// Check that both messages appear in the request
|
// Check that both messages appear in the request
|
||||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
assert_eq!(request.messages.len(), 3);
|
assert_eq!(request.messages.len(), 3);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -2582,7 +2617,7 @@ fn main() {{
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let (_workspace, _thread_store, thread, context_store) =
|
let (_workspace, _thread_store, thread, context_store, model) =
|
||||||
setup_test_environment(cx, project.clone()).await;
|
setup_test_environment(cx, project.clone()).await;
|
||||||
|
|
||||||
// Open buffer and add it to context
|
// Open buffer and add it to context
|
||||||
|
@ -2601,7 +2636,9 @@ fn main() {{
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||||
let initial_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let initial_request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
// Make sure we don't have a stale file warning yet
|
// Make sure we don't have a stale file warning yet
|
||||||
let has_stale_warning = initial_request.messages.iter().any(|msg| {
|
let has_stale_warning = initial_request.messages.iter().any(|msg| {
|
||||||
|
@ -2634,7 +2671,9 @@ fn main() {{
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a new request and check for the stale buffer warning
|
// Create a new request and check for the stale buffer warning
|
||||||
let new_request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
let new_request = thread.update(cx, |thread, cx| {
|
||||||
|
thread.to_completion_request(model.clone(), cx)
|
||||||
|
});
|
||||||
|
|
||||||
// We should have a stale file warning as the last message
|
// We should have a stale file warning as the last message
|
||||||
let last_message = new_request
|
let last_message = new_request
|
||||||
|
@ -2667,6 +2706,7 @@ fn main() {{
|
||||||
ThemeSettings::register(cx);
|
ThemeSettings::register(cx);
|
||||||
ContextServerSettings::register(cx);
|
ContextServerSettings::register(cx);
|
||||||
EditorSettings::register(cx);
|
EditorSettings::register(cx);
|
||||||
|
ToolRegistry::default_global(cx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2688,6 +2728,7 @@ fn main() {{
|
||||||
Entity<ThreadStore>,
|
Entity<ThreadStore>,
|
||||||
Entity<Thread>,
|
Entity<Thread>,
|
||||||
Entity<ContextStore>,
|
Entity<ContextStore>,
|
||||||
|
Arc<dyn LanguageModel>,
|
||||||
) {
|
) {
|
||||||
let (workspace, cx) =
|
let (workspace, cx) =
|
||||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||||
|
@ -2708,7 +2749,10 @@ fn main() {{
|
||||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||||
|
|
||||||
(workspace, thread_store, thread, context_store)
|
let model = FakeLanguageModel::default();
|
||||||
|
let model: Arc<dyn LanguageModel> = Arc::new(model);
|
||||||
|
|
||||||
|
(workspace, thread_store, thread, context_store, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_file_to_context(
|
async fn add_file_to_context(
|
||||||
|
|
|
@ -28,6 +28,7 @@ parking_lot.workspace = true
|
||||||
paths.workspace = true
|
paths.workspace = true
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
text.workspace = true
|
text.workspace = true
|
||||||
util.workspace = true
|
util.workspace = true
|
||||||
uuid.workspace = true
|
uuid.workspace = true
|
||||||
|
|
|
@ -48,6 +48,20 @@ impl ProjectContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct ModelContext {
|
||||||
|
pub available_tools: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct PromptTemplateContext {
|
||||||
|
#[serde(flatten)]
|
||||||
|
project: ProjectContext,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
model: ModelContext,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct UserRulesContext {
|
pub struct UserRulesContext {
|
||||||
pub uuid: UserPromptId,
|
pub uuid: UserPromptId,
|
||||||
|
@ -124,9 +138,40 @@ impl PromptBuilder {
|
||||||
.unwrap_or_else(|| Arc::new(Self::new(None).unwrap()))
|
.unwrap_or_else(|| Arc::new(Self::new(None).unwrap()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper function for handlebars templates to check if a specific tool is enabled
|
||||||
|
fn has_tool_helper(
|
||||||
|
h: &handlebars::Helper,
|
||||||
|
_: &Handlebars,
|
||||||
|
ctx: &handlebars::Context,
|
||||||
|
_: &mut handlebars::RenderContext,
|
||||||
|
out: &mut dyn handlebars::Output,
|
||||||
|
) -> handlebars::HelperResult {
|
||||||
|
let tool_name = h.param(0).and_then(|v| v.value().as_str()).ok_or_else(|| {
|
||||||
|
handlebars::RenderError::new("has_tool helper: missing or invalid tool name parameter")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let enabled_tools = ctx
|
||||||
|
.data()
|
||||||
|
.get("available_tools")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.map(|arr| arr.iter().filter_map(|v| v.as_str()).collect::<Vec<&str>>())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
handlebars::RenderError::new(
|
||||||
|
"has_tool handlebars helper: available_tools not found or not an array",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if enabled_tools.contains(&tool_name) {
|
||||||
|
out.write("true")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new(loading_params: Option<PromptLoadingParams>) -> Result<Self> {
|
pub fn new(loading_params: Option<PromptLoadingParams>) -> Result<Self> {
|
||||||
let mut handlebars = Handlebars::new();
|
let mut handlebars = Handlebars::new();
|
||||||
Self::register_built_in_templates(&mut handlebars)?;
|
Self::register_built_in_templates(&mut handlebars)?;
|
||||||
|
handlebars.register_helper("has_tool", Box::new(Self::has_tool_helper));
|
||||||
|
|
||||||
let handlebars = Arc::new(Mutex::new(handlebars));
|
let handlebars = Arc::new(Mutex::new(handlebars));
|
||||||
|
|
||||||
|
@ -278,10 +323,16 @@ impl PromptBuilder {
|
||||||
pub fn generate_assistant_system_prompt(
|
pub fn generate_assistant_system_prompt(
|
||||||
&self,
|
&self,
|
||||||
context: &ProjectContext,
|
context: &ProjectContext,
|
||||||
|
model_context: &ModelContext,
|
||||||
) -> Result<String, RenderError> {
|
) -> Result<String, RenderError> {
|
||||||
|
let template_context = PromptTemplateContext {
|
||||||
|
project: context.clone(),
|
||||||
|
model: model_context.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
self.handlebars
|
self.handlebars
|
||||||
.lock()
|
.lock()
|
||||||
.render("assistant_system_prompt", context)
|
.render("assistant_system_prompt", &template_context)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generate_inline_transformation_prompt(
|
pub fn generate_inline_transformation_prompt(
|
||||||
|
@ -398,6 +449,7 @@ impl PromptBuilder {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use serde_json;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -416,9 +468,73 @@ mod test {
|
||||||
contents: "Rules contents".into(),
|
contents: "Rules contents".into(),
|
||||||
}];
|
}];
|
||||||
let project_context = ProjectContext::new(worktrees, default_user_rules);
|
let project_context = ProjectContext::new(worktrees, default_user_rules);
|
||||||
PromptBuilder::new(None)
|
let model_context = ModelContext {
|
||||||
|
available_tools: ["grep".into()].to_vec(),
|
||||||
|
};
|
||||||
|
let prompt = PromptBuilder::new(None)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.generate_assistant_system_prompt(&project_context)
|
.generate_assistant_system_prompt(&project_context, &model_context)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
prompt.contains("Rules contents"),
|
||||||
|
"Expected default user rules to be in rendered prompt"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_assistant_system_prompt_depends_on_enabled_tools() {
|
||||||
|
let worktrees = vec![WorktreeContext {
|
||||||
|
root_name: "path".into(),
|
||||||
|
rules_file: None,
|
||||||
|
}];
|
||||||
|
let default_user_rules = vec![];
|
||||||
|
let project_context = ProjectContext::new(worktrees, default_user_rules);
|
||||||
|
let prompt_builder = PromptBuilder::new(None).unwrap();
|
||||||
|
|
||||||
|
// When the `grep` tool is enabled, it should be mentioned in the prompt
|
||||||
|
let model_context = ModelContext {
|
||||||
|
available_tools: ["grep".into()].to_vec(),
|
||||||
|
};
|
||||||
|
let prompt_with_grep = prompt_builder
|
||||||
|
.generate_assistant_system_prompt(&project_context, &model_context)
|
||||||
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
prompt_with_grep.contains("grep"),
|
||||||
|
"`grep` tool should be mentioned in prompt when the tool is enabled"
|
||||||
|
);
|
||||||
|
|
||||||
|
// When the `grep` tool is disabled, it should not be mentioned in the prompt
|
||||||
|
let model_context = ModelContext {
|
||||||
|
available_tools: [].to_vec(),
|
||||||
|
};
|
||||||
|
let prompt_without_grep = prompt_builder
|
||||||
|
.generate_assistant_system_prompt(&project_context, &model_context)
|
||||||
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
!prompt_without_grep.contains("grep"),
|
||||||
|
"`grep` tool should not be mentioned in prompt when the tool is disabled"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_has_tool_helper() {
|
||||||
|
let mut handlebars = Handlebars::new();
|
||||||
|
handlebars.register_helper("has_tool", Box::new(PromptBuilder::has_tool_helper));
|
||||||
|
handlebars
|
||||||
|
.register_template_string(
|
||||||
|
"test_template",
|
||||||
|
"{{#if (has_tool 'grep')}}grep is enabled{{else}}grep is disabled{{/if}}",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// grep available
|
||||||
|
let data = serde_json::json!({"available_tools": ["grep", "fetch"]});
|
||||||
|
let result = handlebars.render("test_template", &data).unwrap();
|
||||||
|
assert_eq!(result, "grep is enabled");
|
||||||
|
|
||||||
|
// grep not available
|
||||||
|
let data = serde_json::json!({"available_tools": ["terminal", "fetch"]});
|
||||||
|
let result = handlebars.render("test_template", &data).unwrap();
|
||||||
|
assert_eq!(result, "grep is disabled");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue