Add system prompt and tool permission to agent2 (#35781)

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-08-07 15:40:12 +02:00 committed by GitHub
parent 4dbd24d75f
commit 03876d076e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1111 additions and 304 deletions

8
Cargo.lock generated
View file

@ -138,9 +138,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.20"
version = "0.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403"
checksum = "b7ae3c22c23b64a5c3b7fc8a86fcc7c494e989bd2cd66fdce14a58cfc8078381"
dependencies = [
"anyhow",
"futures 0.3.31",
@ -159,6 +159,7 @@ dependencies = [
"agent-client-protocol",
"agent_servers",
"anyhow",
"assistant_tool",
"client",
"clock",
"cloud_llm_client",
@ -171,10 +172,12 @@ dependencies = [
"gpui_tokio",
"handlebars 4.5.0",
"indoc",
"language",
"language_model",
"language_models",
"log",
"project",
"prompt_store",
"reqwest_client",
"rust-embed",
"schemars",
@ -185,6 +188,7 @@ dependencies = [
"ui",
"util",
"uuid",
"watch",
"workspace-hack",
"worktree",
]

View file

@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.20"
agent-client-protocol = "0.0.21"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"

View file

@ -902,7 +902,7 @@ impl AcpThread {
});
}
pub fn request_tool_call_permission(
pub fn request_tool_call_authorization(
&mut self,
tool_call: acp::ToolCall,
options: Vec<acp::PermissionOption>,

View file

@ -16,6 +16,7 @@ acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
@ -27,6 +28,7 @@ language_model.workspace = true
language_models.workspace = true
log.workspace = true
project.workspace = true
prompt_store.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
@ -36,6 +38,7 @@ smol.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
worktree.workspace = true
workspace-hack.workspace = true
@ -47,6 +50,7 @@ env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true

View file

@ -1,16 +1,39 @@
use crate::ToolCallAuthorization;
use crate::{templates::Templates, AgentResponseEvent, Thread};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use futures::StreamExt;
use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
use anyhow::{anyhow, Context as _, Result};
use futures::{future, StreamExt};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::Project;
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
use crate::{templates::Templates, AgentResponseEvent, Thread};
const RULES_FILE_NAMES: [&'static str; 9] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
"AGENT.md",
"AGENTS.md",
"GEMINI.md",
];
pub struct RulesLoadingError {
pub message: SharedString,
}
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
@ -24,17 +47,247 @@ struct Session {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
/// Shared project context for all threads
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
/// Shared templates for all threads
templates: Arc<Templates>,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
pub fn new(templates: Arc<Templates>) -> Self {
pub async fn new(
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
let project_context = cx
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
.await;
cx.new(|cx| {
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
templates,
project,
prompt_store,
_subscriptions: subscriptions,
}
})
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
})?
.await;
this.update(cx, |this, _| this.project_context.replace(project_context))?;
}
Ok(())
}
fn build_project_context(
project: &Entity<Project>,
prompt_store: Option<&Entity<PromptStore>>,
cx: &mut App,
) -> Task<ProjectContext> {
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
let worktree_tasks = worktrees
.into_iter()
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
})
.collect::<Vec<_>>();
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
prompt_store.read_with(cx, |prompt_store, cx| {
let prompts = prompt_store.default_prompt_metadata();
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
let contents = prompt_store.load(prompt_metadata.id, cx);
async move { (contents.await, prompt_metadata) }
});
cx.background_spawn(future::join_all(load_tasks))
})
} else {
Task::ready(vec![])
};
cx.spawn(async move |_cx| {
let (worktrees, default_user_rules) =
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
let worktrees = worktrees
.into_iter()
.map(|(worktree, _rules_error)| {
// TODO: show error message
// if let Some(rules_error) = rules_error {
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
// }
worktree
})
.collect::<Vec<_>>();
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
},
title: prompt_metadata.title.map(|title| title.to_string()),
contents,
}),
Err(_err) => {
// TODO: show error message
// this.update(cx, |_, cx| {
// cx.emit(RulesLoadingError {
// message: format!("{err:?}").into(),
// });
// })
// .ok();
None
}
})
.collect::<Vec<_>>();
ProjectContext::new(worktrees, default_user_rules)
})
}
fn load_worktree_info_for_system_prompt(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let tree = worktree.read(cx);
let root_name = tree.root_name().into();
let abs_path = tree.abs_path();
let mut context = WorktreeContext {
root_name,
abs_path,
rules_file: None,
};
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
let Some(rules_task) = rules_task else {
return Task::ready((context, None));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
context.rules_file = rules_file;
(context, rules_file_error)
})
}
fn load_worktree_rules_file(
worktree: Entity<Worktree>,
project: Entity<Project>,
cx: &mut App,
) -> Option<Task<Result<RulesFileContext>>> {
let worktree = worktree.read(cx);
let worktree_id = worktree.id();
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| entry.path.clone())
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|path_in_worktree| {
let project_path = ProjectPath {
worktree_id,
path: path_in_worktree.clone(),
};
let buffer_task =
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
let rope_task = cx.spawn(async move |cx| {
buffer_task.await?.read_with(cx, |buffer, cx| {
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
})?
});
// Build a string from the rope on a background thread.
cx.background_spawn(async move {
let (project_entry_id, rope) = rope_task.await?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
text: rope.to_string().trim().to_string(),
project_entry_id: project_entry_id.to_usize(),
})
})
})
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
_cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.project_context_needs_refresh.send(()).ok();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.project_context_needs_refresh.send(()).ok();
}
}
_ => {}
}
}
fn handle_prompts_updated_event(
&mut self,
_prompt_store: Entity<PromptStore>,
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
self.project_context_needs_refresh.send(()).ok();
}
}
@ -120,8 +373,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
// Create Thread
let (session_id, thread) = agent.update(
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
@ -146,22 +412,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
// Generate session ID
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
log::info!("Created session with ID: {}", session_id);
Ok((session_id, thread))
let thread = cx.new(|_| Thread::new(project, agent.project_context.clone(), action_log, agent.templates.clone(), default_model));
Ok(thread)
},
)??;
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
})
})?;
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
@ -264,6 +519,28 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
)
})??;
}
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
}) => {
let recv = acp_thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(tool_call, options, cx)
})?;
cx.background_spawn(async move {
if let Some(option) = recv
.await
.context("authorization sender was dropped")
.log_err()
{
response
.send(option)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
})
.detach();
}
AgentResponseEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.handle_session_update(
@ -343,3 +620,77 @@ fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
message
}
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
use settings::SettingsStore;
#[gpui::test]
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/",
json!({
"a": {}
}),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
let worktree = project
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
.await
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, _| {
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: None
}]
)
});
// Creating `/a/.rules` updates the project context.
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
assert_eq!(
agent.project_context.borrow().worktrees,
vec![WorktreeContext {
root_name: "a".into(),
abs_path: Path::new("/a").into(),
rules_file: Some(RulesFileContext {
path_in_worktree: Path::new(".rules").into(),
text: "".into(),
project_entry_id: rules_entry.id.to_usize()
})
}]
)
});
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
language::init(cx);
});
}
}

View file

@ -1,6 +1,5 @@
mod agent;
mod native_agent_server;
mod prompts;
mod templates;
mod thread;
mod tools;

View file

@ -3,8 +3,9 @@ use std::rc::Rc;
use agent_servers::AgentServer;
use anyhow::Result;
use gpui::{App, AppContext, Entity, Task};
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
@ -32,21 +33,22 @@ impl AgentServer for NativeAgentServer {
fn connect(
&self,
_root_dir: &Path,
_project: &Entity<Project>,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
log::info!(
"NativeAgentServer::connect called for path: {:?}",
_root_dir
);
let project = project.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
// Create templates (you might want to load these from files or resources)
let templates = Templates::new();
let prompt_store = prompt_store.await?;
// Create the native agent
log::debug!("Creating native agent entity");
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View file

@ -1,35 +0,0 @@
use crate::{
templates::{BaseTemplate, Template, Templates, WorktreeData},
thread::Prompt,
};
use anyhow::Result;
use gpui::{App, Entity};
use project::Project;
pub struct BasePrompt {
project: Entity<Project>,
}
impl BasePrompt {
pub fn new(project: Entity<Project>) -> Self {
Self { project }
}
}
impl Prompt for BasePrompt {
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
BaseTemplate {
os: std::env::consts::OS.to_string(),
shell: util::get_system_shell(),
worktrees: self
.project
.read(cx)
.worktrees(cx)
.map(|worktree| WorktreeData {
root_name: worktree.read(cx).root_name().to_string(),
})
.collect(),
}
.render(templates)
}
}

View file

@ -1,9 +1,9 @@
use std::sync::Arc;
use anyhow::Result;
use gpui::SharedString;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
@ -15,6 +15,8 @@ pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.set_strict_mode(true);
handlebars.register_helper("contains", Box::new(contains));
handlebars.register_embed_templates::<Assets>().unwrap();
Arc::new(Self(handlebars))
}
@ -31,22 +33,6 @@ pub trait Template: Sized {
}
}
#[derive(Serialize)]
pub struct BaseTemplate {
pub os: String,
pub shell: String,
pub worktrees: Vec<WorktreeData>,
}
impl Template for BaseTemplate {
const TEMPLATE_NAME: &'static str = "base.hbs";
}
#[derive(Serialize)]
pub struct WorktreeData {
pub root_name: String,
}
#[derive(Serialize)]
pub struct GlobTemplate {
pub project_roots: String,
@ -55,3 +41,56 @@ pub struct GlobTemplate {
impl Template for GlobTemplate {
const TEMPLATE_NAME: &'static str = "glob.hbs";
}
#[derive(Serialize)]
pub struct SystemPromptTemplate<'a> {
#[serde(flatten)]
pub project: &'a prompt_store::ProjectContext,
pub available_tools: Vec<SharedString>,
}
impl Template for SystemPromptTemplate<'_> {
const TEMPLATE_NAME: &'static str = "system_prompt.hbs";
}
/// Handlebars helper for checking if an item is in a list
fn contains(
h: &handlebars::Helper,
_: &handlebars::Handlebars,
_: &handlebars::Context,
_: &mut handlebars::RenderContext,
out: &mut dyn handlebars::Output,
) -> handlebars::HelperResult {
let list = h
.param(0)
.and_then(|v| v.value().as_array())
.ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid list parameter")
})?;
let query = h.param(1).map(|v| v.value()).ok_or_else(|| {
handlebars::RenderError::new("contains: missing or invalid query parameter")
})?;
if list.contains(&query) {
out.write("true")?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_system_prompt_template() {
let project = prompt_store::ProjectContext::default();
let template = SystemPromptTemplate {
project: &project,
available_tools: vec!["echo".into()],
};
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Fixing Diagnostics"));
}
}

View file

@ -1,56 +0,0 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{root_name}}`
{{/each}}
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- Bias towards not asking the user for help if you can find the answer yourself.
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}

View file

@ -0,0 +1,178 @@
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
## Communication
1. Be conversational but professional.
2. Refer to the user in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
{{#if (gt (len available_tools) 0)}}
## Tool Use
1. Make sure to adhere to the tools schema.
2. Provide every required argument.
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
7. Avoid HTML entity escaping - use plain characters instead.
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
If appropriate, use tool calls to explore the current project, which contains the following root directories:
{{#each worktrees}}
- `{{abs_path}}`
{{/each}}
- Bias towards not asking the user for help if you can find the answer yourself.
- When providing paths to tools, the path should always start with the name of a project root directory listed above.
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
{{# if (contains available_tools 'grep') }}
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
{{/if}}
{{else}}
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
{{/if}}
## Code Block Formatting
Whenever you mention a code block, you MUST use ONLY use the following format:
```path/to/Something.blah#L123-456
(code goes here)
```
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
{{#if (gt (len available_tools) 0)}}
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
## Debugging
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
{{/if}}
## Calling External APIs
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
## System Information
Operating System: {{os}}
Default Shell: {{shell}}
{{#if (or has_rules has_user_rules)}}
## User's Custom Instructions
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
{{#if has_rules}}
There are project rules that apply to these root directories:
{{#each worktrees}}
{{#if rules_file}}
`{{root_name}}/{{rules_file.path_in_worktree}}`:
``````
{{{rules_file.text}}}
``````
{{/if}}
{{/each}}
{{/if}}
{{#if has_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#if title}}
Rules title: {{title}}
{{/if}}
``````
{{contents}}}
``````
{{/each}}
{{/if}}
{{/if}}

View file

@ -1,30 +1,34 @@
use super::*;
use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol as acp;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use assistant_tool::ActionLog;
use client::{Client, UserStore};
use fs::FakeFs;
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
use indoc::indoc;
use language_model::{
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
StopReason,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use smol::stream::StreamExt;
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
mod test_tools;
use test_tools::*;
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_system_prompt(cx: &mut TestAppContext) {
let ThreadTest {
model,
thread,
project_context,
..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
pending_completions.len(),
1,
"unexpected pending completions: {:?}",
pending_completions
);
let pending_completion = pending_completions.pop().unwrap();
assert_eq!(pending_completion.messages[0].role, Role::System);
let system_message = &pending_completion.messages[0];
let system_prompt = system_message.content[0].to_str().unwrap();
assert!(
system_prompt.contains("test-shell"),
"unexpected system message: {:?}",
system_message
);
assert!(
system_prompt.contains("## Fixing Diagnostics"),
"unexpected system message: {:?}",
system_message
);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -175,7 +218,104 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
async fn test_tool_authorization(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send(model.clone(), "abc", cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_1".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: "tool_id_2".into(),
name: ToolRequiringPermission.name().into(),
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
},
));
fake_model.end_last_completion_stream();
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
// Approve the first
tool_call_auth_1
.response
.send(tool_call_auth_1.options[1].id.clone())
.unwrap();
cx.run_until_parked();
// Reject the second
tool_call_auth_2
.response
.send(tool_call_auth_1.options[2].id.clone())
.unwrap();
cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap();
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: tool_call_auth_1.tool_call.title.into(),
is_error: false,
content: "Allowed".into(),
output: None
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: tool_call_auth_2.tool_call.title.into(),
is_error: true,
content: "Permission to run tool denied by user".into(),
output: None
})
]
);
}
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
) -> ToolCallAuthorization {
loop {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
.map(|o| o.kind)
.collect::<Vec<_>>();
assert_eq!(
permission_kinds,
vec![
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::AllowOnce,
acp::PermissionOptionKind::RejectOnce,
]
);
return tool_call_authorization;
}
}
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -214,7 +354,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
}
#[gpui::test]
#[ignore = "temporarily disabled until it can be run on CI"]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
@ -281,12 +421,10 @@ async fn test_cancellation(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
let fake_model = Arc::new(FakeLanguageModel::default());
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.send(fake_model.clone(), "Hello", cx)
});
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
@ -343,8 +481,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@ -366,12 +512,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(listed_models[0].id().0, "fake");
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
// Create a thread using new_thread
let cwd = Path::new("/test");
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| {
@ -457,12 +598,13 @@ fn stop_events(
struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
}
enum TestModel {
Sonnet4,
Sonnet4Thinking,
Fake(Arc<FakeLanguageModel>),
Fake,
}
impl TestModel {
@ -470,7 +612,7 @@ impl TestModel {
match self {
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
TestModel::Fake(fake_model) => fake_model.id(),
TestModel::Fake => unreachable!(),
}
}
}
@ -499,8 +641,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake(model) = model {
Task::ready(model as Arc<_>)
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
let model_id = model.id();
let models = LanguageModelRegistry::read_global(cx);
@ -520,9 +662,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
})
.await;
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
ThreadTest { model, thread }
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|_| {
Thread::new(
project,
project_context.clone(),
action_log,
templates,
model.clone(),
)
});
ThreadTest {
model,
thread,
project_context,
}
}
#[cfg(test)]

View file

@ -19,6 +19,10 @@ impl AgentTool for EchoTool {
"echo".into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
false
}
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
Task::ready(Ok(input.text))
}
@ -40,6 +44,10 @@ impl AgentTool for DelayTool {
"delay".into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
false
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
where
Self: Sized,
@ -51,6 +59,31 @@ impl AgentTool for DelayTool {
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct ToolRequiringPermissionInput {}
pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
true
}
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>>
where
Self: Sized,
{
cx.foreground_executor()
.spawn(async move { Ok("Allowed".to_string()) })
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct InfiniteToolInput {}
@ -63,6 +96,10 @@ impl AgentTool for InfiniteTool {
"infinite".into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
false
}
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
cx.foreground_executor().spawn(async move {
future::pending::<()>().await;
@ -100,6 +137,10 @@ impl AgentTool for WordListTool {
"word_list".into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
false
}
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
Task::ready(Ok("ok".to_string()))
}

View file

@ -1,9 +1,13 @@
use crate::{prompts::BasePrompt, templates::Templates};
use crate::templates::{SystemPromptTemplate, Template, Templates};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::ActionLog;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
use futures::{channel::mpsc, stream::FuturesUnordered};
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
@ -13,10 +17,11 @@ use language_model::{
};
use log;
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::Deserialize;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, fmt::Write, sync::Arc};
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
use util::{markdown::MarkdownCodeBlock, ResultExt};
#[derive(Debug, Clone)]
@ -97,11 +102,15 @@ pub enum AgentResponseEvent {
Thinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
Stop(acp::StopReason),
}
pub trait Prompt {
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
#[derive(Debug)]
pub struct ToolCallAuthorization {
pub tool_call: acp::ToolCall,
pub options: Vec<acp::PermissionOption>,
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
pub struct Thread {
@ -112,28 +121,31 @@ pub struct Thread {
/// we run tools, report their results.
running_turn: Option<Task<()>>,
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
// action_log: Entity<ActionLog>,
_action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
project: Entity<Project>,
_project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
) -> Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
system_prompts: vec![Arc::new(BasePrompt::new(project))],
running_turn: None,
pending_tool_uses: HashMap::default(),
tools: BTreeMap::default(),
project_context,
templates,
selected_model: default_model,
_action_log: action_log,
}
}
@ -188,6 +200,7 @@ impl Thread {
cx.notify();
let (events_tx, events_rx) =
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
let event_stream = AgentResponseEventStream(events_tx);
let user_message_ix = self.messages.len();
self.messages.push(AgentMessage {
@ -222,12 +235,7 @@ impl Thread {
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
if let Some(reason) = to_acp_stop_reason(reason) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Stop(reason)))
.ok();
}
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
thread.update(cx, |thread, _cx| {
thread.messages.truncate(user_message_ix);
@ -240,14 +248,16 @@ impl Thread {
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_streamed_completion_event(
event, &events_tx, cx,
event,
&event_stream,
cx,
));
})
.ok();
}
Err(error) => {
log::error!("Error in completion stream: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
break;
}
}
@ -266,11 +276,7 @@ impl Thread {
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
to_acp_tool_call_update(&tool_result),
)))
.ok();
event_stream.send_tool_call_result(&tool_result);
thread
.update(cx, |thread, _cx| {
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
@ -291,7 +297,7 @@ impl Thread {
if let Err(error) = turn_result {
log::error!("Turn execution failed: {:?}", error);
events_tx.unbounded_send(Err(error)).ok();
event_stream.send_error(error);
} else {
log::info!("Turn execution completed successfully");
}
@ -299,33 +305,29 @@ impl Thread {
events_rx
}
pub fn build_system_message(&self, cx: &App) -> Option<AgentMessage> {
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let mut system_message = AgentMessage {
let prompt = SystemPromptTemplate {
project: &self.project_context.borrow(),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
.context("failed to build system prompt")
.expect("Invalid template");
log::debug!("System message built");
AgentMessage {
role: Role::System,
content: Vec::new(),
};
for prompt in &self.system_prompts {
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
system_message
.content
.push(MessageContent::Text(rendered_prompt));
content: vec![prompt.into()],
}
}
let result = (!system_message.content.is_empty()).then_some(system_message);
log::debug!("System message built: {}", result.is_some());
result
}
/// A helper method that's called on every streamed completion event.
/// Returns an optional tool result task, which the main agentic loop in
/// send will send back to the model when it resolves.
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
@ -338,13 +340,13 @@ impl Thread {
content: Vec::new(),
});
}
Text(new_text) => self.handle_text_event(new_text, events_tx, cx),
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
Thinking { text, signature } => {
self.handle_thinking_event(text, signature, events_tx, cx)
self.handle_thinking_event(text, signature, event_stream, cx)
}
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
ToolUse(tool_use) => {
return self.handle_tool_use_event(tool_use, events_tx, cx);
return self.handle_tool_use_event(tool_use, event_stream, cx);
}
ToolUseJsonParseError {
id,
@ -369,12 +371,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
events_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone())))
.ok();
events_stream.send_text(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
@ -390,12 +390,10 @@ impl Thread {
&mut self,
new_text: String,
new_signature: Option<String>,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone())))
.ok();
event_stream.send_thinking(&new_text);
let last_message = self.last_assistant_message();
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
@ -423,7 +421,7 @@ impl Thread {
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
@ -446,32 +444,18 @@ impl Thread {
}
});
if push_new_tool_use {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
})))
.ok();
event_stream.send_tool_call(&tool_use);
last_message
.content
.push(MessageContent::ToolUse(tool_use.clone()));
} else {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
event_stream.send_tool_call_update(
&tool_use.id,
acp::ToolCallUpdateFields {
raw_input: Some(tool_use.input.clone()),
..Default::default()
},
},
)))
.ok();
);
}
if !tool_use.is_input_complete {
@ -479,22 +463,10 @@ impl Thread {
}
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
events_tx
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use.id.to_string().into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
},
)))
.ok();
let pending_tool_result = tool.clone().run(tool_use.input, cx);
let tool_result =
self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx);
Some(cx.foreground_executor().spawn(async move {
match pending_tool_result.await {
match tool_result.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
@ -523,6 +495,30 @@ impl Thread {
}
}
fn run_tool(
&self,
tool: Arc<dyn AnyAgentTool>,
tool_use: LanguageModelToolUse,
event_stream: AgentResponseEventStream,
cx: &mut Context<Self>,
) -> Task<Result<String>> {
let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx);
cx.spawn(async move |_this, cx| {
if needs_authorization? {
event_stream.authorize_tool_call(&tool_use).await?;
}
event_stream.send_tool_call_update(
&tool_use.id,
acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
},
);
cx.update(|cx| tool.run(tool_use.input, cx))?.await
})
}
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -575,7 +571,7 @@ impl Thread {
log::debug!("Completion intent: {:?}", completion_intent);
log::debug!("Completion mode: {:?}", self.completion_mode);
let messages = self.build_request_messages(cx);
let messages = self.build_request_messages();
log::info!("Request will include {} messages", messages.len());
let tools: Vec<LanguageModelRequestTool> = self
@ -613,14 +609,13 @@ impl Thread {
request
}
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
self.messages.len()
);
let messages = self
.build_system_message(cx)
let messages = Some(self.build_system_message())
.iter()
.chain(self.messages.iter())
.map(|message| {
@ -674,6 +669,10 @@ where
schemars::schema_for!(Self::Input)
}
/// Returns true if the tool needs the users's authorization
/// before running.
fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool;
/// Runs the tool with the provided input.
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
@ -688,6 +687,7 @@ pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool>;
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
}
@ -707,6 +707,14 @@ where
Ok(serde_json::to_value(self.0.input_schema(format))?)
}
fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result<bool> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
Ok(input) => Ok(self.0.needs_authorization(input, cx)),
Err(error) => Err(anyhow!(error)),
}
}
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
match parsed_input {
@ -716,16 +724,100 @@ where
}
}
fn to_acp_stop_reason(reason: StopReason) -> Option<acp::StopReason> {
match reason {
StopReason::EndTurn => Some(acp::StopReason::EndTurn),
StopReason::MaxTokens => Some(acp::StopReason::MaxTokens),
StopReason::Refusal => Some(acp::StopReason::Refusal),
StopReason::ToolUse => None,
#[derive(Clone)]
struct AgentResponseEventStream(
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
.ok();
}
fn send_thinking(&self, text: &str) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
.ok();
}
fn authorize_tool_call(
&self,
tool_use: &LanguageModelToolUse,
) -> impl use<> + Future<Output = Result<()>> {
let (response_tx, response_rx) = oneshot::channel();
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
ToolCallAuthorization {
tool_call: acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
},
options: vec![
acp::PermissionOption {
id: acp::PermissionOptionId("always_allow".into()),
name: "Always Allow".into(),
kind: acp::PermissionOptionKind::AllowAlways,
},
acp::PermissionOption {
id: acp::PermissionOptionId("allow".into()),
name: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
},
acp::PermissionOption {
id: acp::PermissionOptionId("deny".into()),
name: "Deny".into(),
kind: acp::PermissionOptionKind::RejectOnce,
},
],
response: response_tx,
},
)))
.ok();
async move {
match response_rx.await?.0.as_ref() {
"allow" | "always_allow" => Ok(()),
_ => Err(anyhow!("Permission to run tool denied by user")),
}
}
}
fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate {
fn send_tool_call(&self, tool_use: &LanguageModelToolUse) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
id: acp::ToolCallId(tool_use.id.to_string().into()),
title: tool_use.name.to_string(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::Pending,
content: vec![],
locations: vec![],
raw_input: Some(tool_use.input.clone()),
})))
.ok();
}
fn send_tool_call_update(
&self,
tool_use_id: &LanguageModelToolUseId,
fields: acp::ToolCallUpdateFields,
) {
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()),
fields,
},
)))
.ok();
}
fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) {
let status = if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
@ -743,6 +835,8 @@ fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCa
}
}
};
self.0
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
fields: acp::ToolCallUpdateFields {
@ -750,5 +844,33 @@ fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCa
content: Some(vec![content]),
..Default::default()
},
},
)))
.ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
}
}
fn send_error(&self, error: LanguageModelCompletionError) {
self.0.unbounded_send(Err(error)).ok();
}
}

View file

@ -46,6 +46,10 @@ impl AgentTool for GlobTool {
.into()
}
fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool {
false
}
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
let path_matcher = match PathMatcher::new([&input.glob]) {
Ok(matcher) => matcher,

View file

@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate {
let response = cx
.update(|cx| {
self.thread.borrow().update(cx, |thread, cx| {
thread.request_tool_call_permission(tool_call, acp_options, cx)
thread.request_tool_call_authorization(tool_call, acp_options, cx)
})
})?
.context("Failed to update thread")?

View file

@ -210,7 +210,7 @@ impl acp::Client for ClientDelegate {
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
})?;
let result = rx.await;

View file

@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool {
let chosen_option = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
claude_tool.as_acp(tool_call_id),
vec![
acp::PermissionOption {

View file

@ -3147,7 +3147,7 @@ mod tests {
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_permission(
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,

View file

@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::thinking_tool::ThinkingTool;
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use find_path_tool::*;
pub use grep_tool::{GrepTool, GrepToolInput};
pub use open_tool::OpenTool;
pub use project_notifications_tool::ProjectNotificationsTool;

View file

@ -18,7 +18,7 @@ use util::{ResultExt, get_system_shell};
use crate::UserPromptId;
#[derive(Debug, Clone, Serialize)]
#[derive(Default, Debug, Clone, Serialize)]
pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>,
/// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this.
@ -71,14 +71,14 @@ pub struct UserRulesContext {
pub contents: String,
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Eq, PartialEq, Serialize)]
pub struct WorktreeContext {
pub root_name: String,
pub abs_path: Arc<Path>,
pub rules_file: Option<RulesFileContext>,
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Eq, PartialEq, Serialize)]
pub struct RulesFileContext {
pub path_in_worktree: Arc<Path>,
pub text: String,