diff --git a/Cargo.lock b/Cargo.lock index b74928e05d..e6a0b6c75f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.20" +version = "0.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403" +checksum = "b7ae3c22c23b64a5c3b7fc8a86fcc7c494e989bd2cd66fdce14a58cfc8078381" dependencies = [ "anyhow", "futures 0.3.31", @@ -159,6 +159,7 @@ dependencies = [ "agent-client-protocol", "agent_servers", "anyhow", + "assistant_tool", "client", "clock", "cloud_llm_client", @@ -171,10 +172,12 @@ dependencies = [ "gpui_tokio", "handlebars 4.5.0", "indoc", + "language", "language_model", "language_models", "log", "project", + "prompt_store", "reqwest_client", "rust-embed", "schemars", @@ -185,6 +188,7 @@ dependencies = [ "ui", "util", "uuid", + "watch", "workspace-hack", "worktree", ] diff --git a/Cargo.toml b/Cargo.toml index 86f1b8b0a3..6bff713aaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.20" +agent-client-protocol = "0.0.21" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index be9952fd55..1671003023 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -902,7 +902,7 @@ impl AcpThread { }); } - pub fn request_tool_call_permission( + pub fn request_tool_call_authorization( &mut self, tool_call: acp::ToolCall, options: Vec, diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 74aa2993dd..21a043fd98 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -16,6 +16,7 @@ acp_thread.workspace = true 
agent-client-protocol.workspace = true agent_servers.workspace = true anyhow.workspace = true +assistant_tool.workspace = true cloud_llm_client.workspace = true collections.workspace = true fs.workspace = true @@ -27,6 +28,7 @@ language_model.workspace = true language_models.workspace = true log.workspace = true project.workspace = true +prompt_store.workspace = true rust-embed.workspace = true schemars.workspace = true serde.workspace = true @@ -36,6 +38,7 @@ smol.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true +watch.workspace = true worktree.workspace = true workspace-hack.workspace = true @@ -47,6 +50,7 @@ env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true +language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } project = { workspace = true, "features" = ["test-support"] } reqwest_client.workspace = true diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 305a31fc98..5c0acb3fb1 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,16 +1,39 @@ +use crate::ToolCallAuthorization; +use crate::{templates::Templates, AgentResponseEvent, Thread}; use acp_thread::ModelSelector; use agent_client_protocol as acp; -use anyhow::{anyhow, Result}; -use futures::StreamExt; -use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity}; +use anyhow::{anyhow, Context as _, Result}; +use futures::{future, StreamExt}; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, +}; use language_model::{LanguageModel, LanguageModelRegistry}; -use project::Project; +use project::{Project, ProjectItem, ProjectPath, Worktree}; +use prompt_store::{ + ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, +}; +use 
std::cell::RefCell; use std::collections::HashMap; use std::path::Path; use std::rc::Rc; use std::sync::Arc; +use util::ResultExt; -use crate::{templates::Templates, AgentResponseEvent, Thread}; +const RULES_FILE_NAMES: [&'static str; 9] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + ".github/copilot-instructions.md", + "CLAUDE.md", + "AGENT.md", + "AGENTS.md", + "GEMINI.md", +]; + +pub struct RulesLoadingError { + pub message: SharedString, +} /// Holds both the internal Thread and the AcpThread for a session struct Session { @@ -24,17 +47,247 @@ struct Session { pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, + /// Shared project context for all threads + project_context: Rc>, + project_context_needs_refresh: watch::Sender<()>, + _maintain_project_context: Task>, /// Shared templates for all threads templates: Arc, + project: Entity, + prompt_store: Option>, + _subscriptions: Vec, } impl NativeAgent { - pub fn new(templates: Arc) -> Self { + pub async fn new( + project: Entity, + templates: Arc, + prompt_store: Option>, + cx: &mut AsyncApp, + ) -> Result> { log::info!("Creating new NativeAgent"); - Self { - sessions: HashMap::new(), - templates, + + let project_context = cx + .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? 
+ .await; + + cx.new(|cx| { + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; + if let Some(prompt_store) = prompt_store.as_ref() { + subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) + } + + let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = + watch::channel(()); + Self { + sessions: HashMap::new(), + project_context: Rc::new(RefCell::new(project_context)), + project_context_needs_refresh: project_context_needs_refresh_tx, + _maintain_project_context: cx.spawn(async move |this, cx| { + Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await + }), + templates, + project, + prompt_store, + _subscriptions: subscriptions, + } + }) + } + + async fn maintain_project_context( + this: WeakEntity, + mut needs_refresh: watch::Receiver<()>, + cx: &mut AsyncApp, + ) -> Result<()> { + while needs_refresh.changed().await.is_ok() { + let project_context = this + .update(cx, |this, cx| { + Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) + })? 
+ .await; + this.update(cx, |this, _| this.project_context.replace(project_context))?; } + + Ok(()) + } + + fn build_project_context( + project: &Entity, + prompt_store: Option<&Entity>, + cx: &mut App, + ) -> Task { + let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); + let worktree_tasks = worktrees + .into_iter() + .map(|worktree| { + Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) + }) + .collect::>(); + let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { + prompt_store.read_with(cx, |prompt_store, cx| { + let prompts = prompt_store.default_prompt_metadata(); + let load_tasks = prompts.into_iter().map(|prompt_metadata| { + let contents = prompt_store.load(prompt_metadata.id, cx); + async move { (contents.await, prompt_metadata) } + }); + cx.background_spawn(future::join_all(load_tasks)) + }) + } else { + Task::ready(vec![]) + }; + + cx.spawn(async move |_cx| { + let (worktrees, default_user_rules) = + future::join(future::join_all(worktree_tasks), default_user_rules_task).await; + + let worktrees = worktrees + .into_iter() + .map(|(worktree, _rules_error)| { + // TODO: show error message + // if let Some(rules_error) = rules_error { + // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); + // } + worktree + }) + .collect::>(); + + let default_user_rules = default_user_rules + .into_iter() + .flat_map(|(contents, prompt_metadata)| match contents { + Ok(contents) => Some(UserRulesContext { + uuid: match prompt_metadata.id { + PromptId::User { uuid } => uuid, + PromptId::EditWorkflow => return None, + }, + title: prompt_metadata.title.map(|title| title.to_string()), + contents, + }), + Err(_err) => { + // TODO: show error message + // this.update(cx, |_, cx| { + // cx.emit(RulesLoadingError { + // message: format!("{err:?}").into(), + // }); + // }) + // .ok(); + None + } + }) + .collect::>(); + + ProjectContext::new(worktrees, default_user_rules) + }) + } + + fn 
load_worktree_info_for_system_prompt( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Task<(WorktreeContext, Option)> { + let tree = worktree.read(cx); + let root_name = tree.root_name().into(); + let abs_path = tree.abs_path(); + + let mut context = WorktreeContext { + root_name, + abs_path, + rules_file: None, + }; + + let rules_task = Self::load_worktree_rules_file(worktree, project, cx); + let Some(rules_task) = rules_task else { + return Task::ready((context, None)); + }; + + cx.spawn(async move |_| { + let (rules_file, rules_file_error) = match rules_task.await { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(RulesLoadingError { + message: format!("{err}").into(), + }), + ), + }; + context.rules_file = rules_file; + (context, rules_file_error) + }) + } + + fn load_worktree_rules_file( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Option>> { + let worktree = worktree.read(cx); + let worktree_id = worktree.id(); + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(name) + .filter(|entry| entry.is_file()) + .map(|entry| entry.path.clone()) + }) + .next(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. + selected_rules_file.map(|path_in_worktree| { + let project_path = ProjectPath { + worktree_id, + path: path_in_worktree.clone(), + }; + let buffer_task = + project.update(cx, |project, cx| project.open_buffer(project_path, cx)); + let rope_task = cx.spawn(async move |cx| { + buffer_task.await?.read_with(cx, |buffer, cx| { + let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; + anyhow::Ok((project_entry_id, buffer.as_rope().clone())) + })? + }); + // Build a string from the rope on a background thread. 
+ cx.background_spawn(async move { + let (project_entry_id, rope) = rope_task.await?; + anyhow::Ok(RulesFileContext { + path_in_worktree, + text: rope.to_string().trim().to_string(), + project_entry_id: project_entry_id.to_usize(), + }) + }) + }) + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + _cx: &mut Context, + ) { + match event { + project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { + self.project_context_needs_refresh.send(()).ok(); + } + project::Event::WorktreeUpdatedEntries(_, items) => { + if items.iter().any(|(path, _, _)| { + RULES_FILE_NAMES + .iter() + .any(|name| path.as_ref() == Path::new(name)) + }) { + self.project_context_needs_refresh.send(()).ok(); + } + } + _ => {} + } + } + + fn handle_prompts_updated_event( + &mut self, + _prompt_store: Entity, + _event: &prompt_store::PromptsUpdatedEvent, + _cx: &mut Context, + ) { + self.project_context_needs_refresh.send(()).ok(); } } @@ -120,8 +373,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection { cx.spawn(async move |cx| { log::debug!("Starting thread creation in async context"); + + // Generate session ID + let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + log::info!("Created session with ID: {}", session_id); + + // Create AcpThread + let acp_thread = cx.update(|cx| { + cx.new(|cx| { + acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx) + }) + })?; + let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?; + // Create Thread - let (session_id, thread) = agent.update( + let thread = agent.update( cx, |agent, cx: &mut gpui::Context| -> Result<_> { // Fetch default model from registry settings @@ -146,22 +412,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { anyhow!("No default model configured. 
Please configure a default model in settings.") })?; - let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model)); - - // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); - log::info!("Created session with ID: {}", session_id); - Ok((session_id, thread)) + let thread = cx.new(|_| Thread::new(project, agent.project_context.clone(), action_log, agent.templates.clone(), default_model)); + Ok(thread) }, )??; - // Create AcpThread - let acp_thread = cx.update(|cx| { - cx.new(|cx| { - acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx) - }) - })?; - // Store the session agent.update(cx, |agent, cx| { agent.sessions.insert( @@ -264,6 +519,28 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) })??; } + AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let recv = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization(tool_call, options, cx) + })?; + cx.background_spawn(async move { + if let Some(option) = recv + .await + .context("authorization sender was dropped") + .log_err() + { + response + .send(option) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } AgentResponseEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.handle_session_update( @@ -343,3 +620,77 @@ fn convert_prompt_to_message(blocks: Vec) -> String { message } + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::TestAppContext; + use serde_json::json; + use settings::SettingsStore; + + #[gpui::test] + async fn test_maintaining_project_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let agent = NativeAgent::new(project.clone(), 
Templates::new(), None, &mut cx.to_async()) + .await + .unwrap(); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.project_context.borrow().worktrees, vec![]) + }); + + let worktree = project + .update(cx, |project, cx| project.create_worktree("/a", true, cx)) + .await + .unwrap(); + cx.run_until_parked(); + agent.read_with(cx, |agent, _| { + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: None + }] + ) + }); + + // Creating `/a/.rules` updates the project context. + fs.insert_file("/a/.rules", Vec::new()).await; + cx.run_until_parked(); + agent.read_with(cx, |agent, cx| { + let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap(); + assert_eq!( + agent.project_context.borrow().worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: Some(RulesFileContext { + path_in_worktree: Path::new(".rules").into(), + text: "".into(), + project_entry_id: rules_entry.id.to_usize() + }) + }] + ) + }); + } + + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + } +} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index aa665fe313..d759f63d89 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -1,6 +1,5 @@ mod agent; mod native_agent_server; -mod prompts; mod templates; mod thread; mod tools; diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent2/src/native_agent_server.rs index aafe70a8a2..dd0188b548 100644 --- a/crates/agent2/src/native_agent_server.rs +++ b/crates/agent2/src/native_agent_server.rs @@ -3,8 +3,9 @@ use std::rc::Rc; use agent_servers::AgentServer; use anyhow::Result; -use gpui::{App, AppContext, Entity, Task}; +use gpui::{App, Entity, Task}; use 
project::Project; +use prompt_store::PromptStore; use crate::{templates::Templates, NativeAgent, NativeAgentConnection}; @@ -32,21 +33,22 @@ impl AgentServer for NativeAgentServer { fn connect( &self, _root_dir: &Path, - _project: &Entity, + project: &Entity, cx: &mut App, ) -> Task>> { log::info!( "NativeAgentServer::connect called for path: {:?}", _root_dir ); + let project = project.clone(); + let prompt_store = PromptStore::global(cx); cx.spawn(async move |cx| { log::debug!("Creating templates for native agent"); - // Create templates (you might want to load these from files or resources) let templates = Templates::new(); + let prompt_store = prompt_store.await?; - // Create the native agent log::debug!("Creating native agent entity"); - let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?; + let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?; // Create the connection wrapper let connection = NativeAgentConnection(agent); diff --git a/crates/agent2/src/prompts.rs b/crates/agent2/src/prompts.rs deleted file mode 100644 index 28507f4968..0000000000 --- a/crates/agent2/src/prompts.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::{ - templates::{BaseTemplate, Template, Templates, WorktreeData}, - thread::Prompt, -}; -use anyhow::Result; -use gpui::{App, Entity}; -use project::Project; - -pub struct BasePrompt { - project: Entity, -} - -impl BasePrompt { - pub fn new(project: Entity) -> Self { - Self { project } - } -} - -impl Prompt for BasePrompt { - fn render(&self, templates: &Templates, cx: &App) -> Result { - BaseTemplate { - os: std::env::consts::OS.to_string(), - shell: util::get_system_shell(), - worktrees: self - .project - .read(cx) - .worktrees(cx) - .map(|worktree| WorktreeData { - root_name: worktree.read(cx).root_name().to_string(), - }) - .collect(), - } - .render(templates) - } -} diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs index 04569369be..7d51a626fc 100644 --- 
a/crates/agent2/src/templates.rs +++ b/crates/agent2/src/templates.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; - use anyhow::Result; +use gpui::SharedString; use handlebars::Handlebars; use rust_embed::RustEmbed; use serde::Serialize; +use std::sync::Arc; #[derive(RustEmbed)] #[folder = "src/templates"] @@ -15,6 +15,8 @@ pub struct Templates(Handlebars<'static>); impl Templates { pub fn new() -> Arc { let mut handlebars = Handlebars::new(); + handlebars.set_strict_mode(true); + handlebars.register_helper("contains", Box::new(contains)); handlebars.register_embed_templates::().unwrap(); Arc::new(Self(handlebars)) } @@ -31,22 +33,6 @@ pub trait Template: Sized { } } -#[derive(Serialize)] -pub struct BaseTemplate { - pub os: String, - pub shell: String, - pub worktrees: Vec, -} - -impl Template for BaseTemplate { - const TEMPLATE_NAME: &'static str = "base.hbs"; -} - -#[derive(Serialize)] -pub struct WorktreeData { - pub root_name: String, -} - #[derive(Serialize)] pub struct GlobTemplate { pub project_roots: String, @@ -55,3 +41,56 @@ pub struct GlobTemplate { impl Template for GlobTemplate { const TEMPLATE_NAME: &'static str = "glob.hbs"; } + +#[derive(Serialize)] +pub struct SystemPromptTemplate<'a> { + #[serde(flatten)] + pub project: &'a prompt_store::ProjectContext, + pub available_tools: Vec, +} + +impl Template for SystemPromptTemplate<'_> { + const TEMPLATE_NAME: &'static str = "system_prompt.hbs"; +} + +/// Handlebars helper for checking if an item is in a list +fn contains( + h: &handlebars::Helper, + _: &handlebars::Handlebars, + _: &handlebars::Context, + _: &mut handlebars::RenderContext, + out: &mut dyn handlebars::Output, +) -> handlebars::HelperResult { + let list = h + .param(0) + .and_then(|v| v.value().as_array()) + .ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid list parameter") + })?; + let query = h.param(1).map(|v| v.value()).ok_or_else(|| { + handlebars::RenderError::new("contains: missing or invalid query parameter") 
+ })?; + + if list.contains(&query) { + out.write("true")?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_system_prompt_template() { + let project = prompt_store::ProjectContext::default(); + let template = SystemPromptTemplate { + project: &project, + available_tools: vec!["echo".into()], + }; + let templates = Templates::new(); + let rendered = template.render(&templates).unwrap(); + assert!(rendered.contains("## Fixing Diagnostics")); + } +} diff --git a/crates/agent2/src/templates/base.hbs b/crates/agent2/src/templates/base.hbs deleted file mode 100644 index 7eef231e32..0000000000 --- a/crates/agent2/src/templates/base.hbs +++ /dev/null @@ -1,56 +0,0 @@ -You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. - -## Communication - -1. Be conversational but professional. -2. Refer to the USER in the second person and yourself in the first person. -3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. -4. NEVER lie or make things up. -5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. - -## Tool Use - -1. Make sure to adhere to the tools schema. -2. Provide every required argument. -3. DO NOT use tools to access items that are already available in the context section. -4. Use only the tools that are currently available. -5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. - -## Searching and Reading - -If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. 
- -If appropriate, use tool calls to explore the current project, which contains the following root directories: - -{{#each worktrees}} -- `{{root_name}}` -{{/each}} - -- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above. -- When looking for symbols in the project, prefer the `grep` tool. -- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. -- Bias towards not asking the user for help if you can find the answer yourself. - -## Fixing Diagnostics - -1. Make 1-2 attempts at fixing diagnostics, then defer to the user. -2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. - -## Debugging - -When debugging, only make code changes if you are certain that you can solve the problem. -Otherwise, follow debugging best practices: -1. Address the root cause instead of the symptoms. -2. Add descriptive logging statements and error messages to track variable and code state. -3. Add test functions and statements to isolate the problem. - -## Calling External APIs - -1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. -2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data. -3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. 
DO NOT hardcode an API key in a place where it can be exposed) - -## System Information - -Operating System: {{os}} -Default Shell: {{shell}} diff --git a/crates/agent2/src/templates/system_prompt.hbs b/crates/agent2/src/templates/system_prompt.hbs new file mode 100644 index 0000000000..a9f67460d8 --- /dev/null +++ b/crates/agent2/src/templates/system_prompt.hbs @@ -0,0 +1,178 @@ +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. + +## Communication + +1. Be conversational but professional. +2. Refer to the user in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. + +{{#if (gt (len available_tools) 0)}} +## Tool Use + +1. Make sure to adhere to the tools schema. +2. Provide every required argument. +3. DO NOT use tools to access items that are already available in the context section. +4. Use only the tools that are currently available. +5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off. +6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers. +7. Avoid HTML entity escaping - use plain characters instead. + +## Searching and Reading + +If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions. 
+ +If appropriate, use tool calls to explore the current project, which contains the following root directories: + +{{#each worktrees}} +- `{{abs_path}}` +{{/each}} + +- Bias towards not asking the user for help if you can find the answer yourself. +- When providing paths to tools, the path should always start with the name of a project root directory listed above. +- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path! +{{# if (contains available_tools 'grep') }} +- When looking for symbols in the project, prefer the `grep` tool. +- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project. +- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file. +{{/if}} +{{else}} +You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you). + +As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally. + +The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response. 
+{{/if}} + +## Code Block Formatting + +Whenever you mention a code block, you MUST use ONLY the following format: +```path/to/Something.blah#L123-456 +(code goes here) +``` +The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah +is a path in the project. (If there is no valid path in the project, then you can use +/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser +does not understand the more common ```language syntax, or bare ``` blocks. It only +understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again. +Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP! +You have made a mistake. You can only ever put paths after triple backticks! + +Based on all the information I've gathered, here's a summary of how this system works: +1. The README file is loaded into the system. +2. The system finds the first two headers, including everything in between. In this case, that would be: +```path/to/README.md#L8-12 +# First Header +This is the info under the first header. +## Sub-header +``` +3. Then the system finds the last header in the README: +```path/to/README.md#L27-29 +## Last Header +This is the last header in the README. +``` +4. Finally, it passes this information on to the next process. + + +In Markdown, hash marks signify headings. For example: +```/dev/null/example.md#L1-3 +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +Here are examples of ways you must never render code blocks: + +In Markdown, hash marks signify headings. For example: +``` +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because it does not include the path. + +In Markdown, hash marks signify headings. 
For example: +```markdown +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because it has the language instead of the path. + +In Markdown, hash marks signify headings. For example: + # Level 1 heading + ## Level 2 heading + ### Level 3 heading + +This example is unacceptable because it uses indentation to mark the code block +instead of backticks with a path. + +In Markdown, hash marks signify headings. For example: +```markdown +/dev/null/example.md#L1-3 +# Level 1 heading +## Level 2 heading +### Level 3 heading +``` + +This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks. + +{{#if (gt (len available_tools) 0)}} +## Fixing Diagnostics + +1. Make 1-2 attempts at fixing diagnostics, then defer to the user. +2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem. + +## Debugging + +When debugging, only make code changes if you are certain that you can solve the problem. +Otherwise, follow debugging best practices: +1. Address the root cause instead of the symptoms. +2. Add descriptive logging statements and error messages to track variable and code state. +3. Add test functions and statements to isolate the problem. + +{{/if}} +## Calling External APIs + +1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission. +2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data. +3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. 
DO NOT hardcode an API key in a place where it can be exposed). + +## System Information + +Operating System: {{os}} +Default Shell: {{shell}} + +{{#if (or has_rules has_user_rules)}} +## User's Custom Instructions + +The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}. + +{{#if has_rules}} +There are project rules that apply to these root directories: +{{#each worktrees}} +{{#if rules_file}} +`{{root_name}}/{{rules_file.path_in_worktree}}`: +`````` +{{{rules_file.text}}} +`````` +{{/if}} +{{/each}} +{{/if}} + +{{#if has_user_rules}} +The user has specified the following rules that should be applied: +{{#each user_rules}} + +{{#if title}} +Rules title: {{title}} +{{/if}} +`````` +{{{contents}}} +`````` +{{/each}} +{{/if}} +{{/if}} diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 330d04b60c..b13b1cbe1a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,30 +1,34 @@ use super::*; use crate::templates::Templates; use acp_thread::AgentConnection; -use agent_client_protocol as acp; +use agent_client_protocol::{self as acp}; +use anyhow::Result; +use assistant_tool::ActionLog; use client::{Client, UserStore}; use fs::FakeFs; +use futures::channel::mpsc::UnboundedReceiver; use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext}; use indoc::indoc; use language_model::{ fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError, - LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent, - StopReason, + LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult, + LanguageModelToolUse, MessageContent, Role, StopReason, }; use project::Project; +use prompt_store::ProjectContext; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use 
serde::{Deserialize, Serialize}; use serde_json::json; use smol::stream::StreamExt; -use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; +use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration}; use util::path; mod test_tools; use test_tools::*; #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +#[ignore = "can't run on CI yet"] async fn test_echo(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +#[ignore = "can't run on CI yet"] async fn test_thinking(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await; @@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_system_prompt(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + project_context, + .. 
+ } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + project_context.borrow_mut().shell = "test-shell".into(); + thread.update(cx, |thread, _| thread.add_tool(EchoTool)); + thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + let mut pending_completions = fake_model.pending_completions(); + assert_eq!( + pending_completions.len(), + 1, + "unexpected pending completions: {:?}", + pending_completions + ); + + let pending_completion = pending_completions.pop().unwrap(); + assert_eq!(pending_completion.messages[0].role, Role::System); + + let system_message = &pending_completion.messages[0]; + let system_prompt = system_message.content[0].to_str().unwrap(); + assert!( + system_prompt.contains("test-shell"), + "unexpected system message: {:?}", + system_message + ); + assert!( + system_prompt.contains("## Fixing Diagnostics"), + "unexpected system message: {:?}", + system_message + ); +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] async fn test_basic_tool_calls(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +#[ignore = "can't run on CI yet"] async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -175,7 +218,104 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +async fn test_tool_authorization(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| { + thread.add_tool(ToolRequiringPermission); + thread.send(model.clone(), "abc", cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_2".into(), + name: ToolRequiringPermission.name().into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + let tool_call_auth_1 = next_tool_call_authorization(&mut events).await; + let tool_call_auth_2 = next_tool_call_authorization(&mut events).await; + + // Approve the first + tool_call_auth_1 + .response + .send(tool_call_auth_1.options[1].id.clone()) + .unwrap(); + cx.run_until_parked(); + + // Reject the second + tool_call_auth_2 + .response + .send(tool_call_auth_1.options[2].id.clone()) + .unwrap(); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + let message = completion.messages.last().unwrap(); + assert_eq!( + message.content, + vec![ + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), + tool_name: tool_call_auth_1.tool_call.title.into(), + is_error: false, + content: "Allowed".into(), + output: None + }), + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), + tool_name: tool_call_auth_2.tool_call.title.into(), + is_error: true, + content: "Permission to run tool denied by user".into(), + output: None + }) + ] + ); +} + +async fn next_tool_call_authorization( + events: 
&mut UnboundedReceiver>, +) -> ToolCallAuthorization { + loop { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + let permission_kinds = tool_call_authorization + .options + .iter() + .map(|o| o.kind) + .collect::>(); + assert_eq!( + permission_kinds, + vec![ + acp::PermissionOptionKind::AllowAlways, + acp::PermissionOptionKind::AllowOnce, + acp::PermissionOptionKind::RejectOnce, + ] + ); + return tool_call_authorization; + } + } +} + +#[gpui::test] +#[ignore = "can't run on CI yet"] async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -214,7 +354,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { } #[gpui::test] -#[ignore = "temporarily disabled until it can be run on CI"] +#[ignore = "can't run on CI yet"] async fn test_cancellation(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await; @@ -281,12 +421,10 @@ async fn test_cancellation(cx: &mut TestAppContext) { #[gpui::test] async fn test_refusal(cx: &mut TestAppContext) { - let fake_model = Arc::new(FakeLanguageModel::default()); - let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await; + let ThreadTest { model, thread, .. 
} = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); - let events = thread.update(cx, |thread, cx| { - thread.send(fake_model.clone(), "Hello", cx) - }); + let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx)); cx.run_until_parked(); thread.read_with(cx, |thread, _| { assert_eq!( @@ -343,8 +481,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) { }); cx.executor().forbid_parking(); + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + fake_fs.insert_tree(path!("/test"), json!({})).await; + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + let cwd = Path::new("/test"); + // Create agent and connection - let agent = cx.new(|_| NativeAgent::new(templates.clone())); + let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async()) + .await + .unwrap(); let connection = NativeAgentConnection(agent.clone()); // Test model_selector returns Some @@ -366,12 +512,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { assert!(!listed_models.is_empty(), "should have at least one model"); assert_eq!(listed_models[0].id().0, "fake"); - // Create a project for new_thread - let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); - let project = Project::test(fake_fs, [Path::new("/test")], cx).await; - // Create a thread using new_thread - let cwd = Path::new("/test"); let connection_rc = Rc::new(connection.clone()); let acp_thread = cx .update(|cx| { @@ -457,12 +598,13 @@ fn stop_events( struct ThreadTest { model: Arc, thread: Entity, + project_context: Rc>, } enum TestModel { Sonnet4, Sonnet4Thinking, - Fake(Arc), + Fake, } impl TestModel { @@ -470,7 +612,7 @@ impl TestModel { match self { TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()), TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()), - 
TestModel::Fake(fake_model) => fake_model.id(), + TestModel::Fake => unreachable!(), } } } @@ -499,8 +641,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); - if let TestModel::Fake(model) = model { - Task::ready(model as Arc<_>) + if let TestModel::Fake = model { + Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>) } else { let model_id = model.id(); let models = LanguageModelRegistry::read_global(cx); @@ -520,9 +662,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { }) .await; - let thread = cx.new(|_| Thread::new(project, templates, model.clone())); - - ThreadTest { model, thread } + let project_context = Rc::new(RefCell::new(ProjectContext::default())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread = cx.new(|_| { + Thread::new( + project, + project_context.clone(), + action_log, + templates, + model.clone(), + ) + }); + ThreadTest { + model, + thread, + project_context, + } } #[cfg(test)] diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index 1847a14fee..a066bb982e 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -19,6 +19,10 @@ impl AgentTool for EchoTool { "echo".into() } + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + false + } + fn run(self: Arc, input: Self::Input, _cx: &mut App) -> Task> { Task::ready(Ok(input.text)) } @@ -40,6 +44,10 @@ impl AgentTool for DelayTool { "delay".into() } + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + false + } + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> where Self: Sized, @@ -51,6 +59,31 @@ impl AgentTool for DelayTool { } } +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct ToolRequiringPermissionInput {} + +pub struct ToolRequiringPermission; + +impl 
AgentTool for ToolRequiringPermission { + type Input = ToolRequiringPermissionInput; + + fn name(&self) -> SharedString { + "tool_requiring_permission".into() + } + + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + true + } + + fn run(self: Arc, _input: Self::Input, cx: &mut App) -> Task> + where + Self: Sized, + { + cx.foreground_executor() + .spawn(async move { Ok("Allowed".to_string()) }) + } +} + #[derive(JsonSchema, Serialize, Deserialize)] pub struct InfiniteToolInput {} @@ -63,6 +96,10 @@ impl AgentTool for InfiniteTool { "infinite".into() } + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + false + } + fn run(self: Arc, _input: Self::Input, cx: &mut App) -> Task> { cx.foreground_executor().spawn(async move { future::pending::<()>().await; @@ -100,6 +137,10 @@ impl AgentTool for WordListTool { "word_list".into() } + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + false + } + fn run(self: Arc, _input: Self::Input, _cx: &mut App) -> Task> { Task::ready(Ok("ok".to_string())) } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index af3aa17ea8..9b17d7e37e 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,9 +1,13 @@ -use crate::{prompts::BasePrompt, templates::Templates}; +use crate::templates::{SystemPromptTemplate, Template, Templates}; use agent_client_protocol as acp; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context as _, Result}; +use assistant_tool::ActionLog; use cloud_llm_client::{CompletionIntent, CompletionMode}; use collections::HashMap; -use futures::{channel::mpsc, stream::FuturesUnordered}; +use futures::{ + channel::{mpsc, oneshot}, + stream::FuturesUnordered, +}; use gpui::{App, Context, Entity, ImageFormat, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, @@ -13,10 +17,11 @@ use language_model::{ }; use log; use 
project::Project; +use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::Deserialize; use smol::stream::StreamExt; -use std::{collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; use util::{markdown::MarkdownCodeBlock, ResultExt}; #[derive(Debug, Clone)] @@ -97,11 +102,15 @@ pub enum AgentResponseEvent { Thinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp::ToolCallUpdate), + ToolCallAuthorization(ToolCallAuthorization), Stop(acp::StopReason), } -pub trait Prompt { - fn render(&self, prompts: &Templates, cx: &App) -> Result; +#[derive(Debug)] +pub struct ToolCallAuthorization { + pub tool_call: acp::ToolCall, + pub options: Vec, + pub response: oneshot::Sender, } pub struct Thread { @@ -112,28 +121,31 @@ pub struct Thread { /// we run tools, report their results. running_turn: Option>, pending_tool_uses: HashMap, - system_prompts: Vec>, tools: BTreeMap>, + project_context: Rc>, templates: Arc, pub selected_model: Arc, - // action_log: Entity, + _action_log: Entity, } impl Thread { pub fn new( - project: Entity, + _project: Entity, + project_context: Rc>, + action_log: Entity, templates: Arc, default_model: Arc, ) -> Self { Self { messages: Vec::new(), completion_mode: CompletionMode::Normal, - system_prompts: vec![Arc::new(BasePrompt::new(project))], running_turn: None, pending_tool_uses: HashMap::default(), tools: BTreeMap::default(), + project_context, templates, selected_model: default_model, + _action_log: action_log, } } @@ -188,6 +200,7 @@ impl Thread { cx.notify(); let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = AgentResponseEventStream(events_tx); let user_message_ix = self.messages.len(); self.messages.push(AgentMessage { @@ -222,12 +235,7 @@ impl Thread { while let Some(event) = events.next().await { match event { Ok(LanguageModelCompletionEvent::Stop(reason)) => { - if let Some(reason) = 
to_acp_stop_reason(reason) { - events_tx - .unbounded_send(Ok(AgentResponseEvent::Stop(reason))) - .ok(); - } - + event_stream.send_stop(reason); if reason == StopReason::Refusal { thread.update(cx, |thread, _cx| { thread.messages.truncate(user_message_ix); @@ -240,14 +248,16 @@ impl Thread { thread .update(cx, |thread, cx| { tool_uses.extend(thread.handle_streamed_completion_event( - event, &events_tx, cx, + event, + &event_stream, + cx, )); }) .ok(); } Err(error) => { log::error!("Error in completion stream: {:?}", error); - events_tx.unbounded_send(Err(error)).ok(); + event_stream.send_error(error); break; } } @@ -266,11 +276,7 @@ impl Thread { while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); - events_tx - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - to_acp_tool_call_update(&tool_result), - ))) - .ok(); + event_stream.send_tool_call_result(&tool_result); thread .update(cx, |thread, _cx| { thread.pending_tool_uses.remove(&tool_result.tool_use_id); @@ -291,7 +297,7 @@ impl Thread { if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error); - events_tx.unbounded_send(Err(error)).ok(); + event_stream.send_error(error); } else { log::info!("Turn execution completed successfully"); } @@ -299,24 +305,20 @@ impl Thread { events_rx } - pub fn build_system_message(&self, cx: &App) -> Option { + pub fn build_system_message(&self) -> AgentMessage { log::debug!("Building system message"); - let mut system_message = AgentMessage { - role: Role::System, - content: Vec::new(), - }; - - for prompt in &self.system_prompts { - if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() { - system_message - .content - .push(MessageContent::Text(rendered_prompt)); - } + let prompt = SystemPromptTemplate { + project: &self.project_context.borrow(), + available_tools: self.tools.keys().cloned().collect(), + } + .render(&self.templates) + .context("failed to build system prompt") 
+ .expect("Invalid template"); + log::debug!("System message built"); + AgentMessage { + role: Role::System, + content: vec![prompt.into()], } - - let result = (!system_message.content.is_empty()).then_some(system_message); - log::debug!("System message built: {}", result.is_some()); - result } /// A helper method that's called on every streamed completion event. @@ -325,7 +327,7 @@ impl Thread { fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, - events_tx: &mpsc::UnboundedSender>, + event_stream: &AgentResponseEventStream, cx: &mut Context, ) -> Option> { log::trace!("Handling streamed completion event: {:?}", event); @@ -338,13 +340,13 @@ impl Thread { content: Vec::new(), }); } - Text(new_text) => self.handle_text_event(new_text, events_tx, cx), + Text(new_text) => self.handle_text_event(new_text, event_stream, cx), Thinking { text, signature } => { - self.handle_thinking_event(text, signature, events_tx, cx) + self.handle_thinking_event(text, signature, event_stream, cx) } RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), ToolUse(tool_use) => { - return self.handle_tool_use_event(tool_use, events_tx, cx); + return self.handle_tool_use_event(tool_use, event_stream, cx); } ToolUseJsonParseError { id, @@ -369,12 +371,10 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - events_tx: &mpsc::UnboundedSender>, + events_stream: &AgentResponseEventStream, cx: &mut Context, ) { - events_tx - .unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone()))) - .ok(); + events_stream.send_text(&new_text); let last_message = self.last_assistant_message(); if let Some(MessageContent::Text(text)) = last_message.content.last_mut() { @@ -390,12 +390,10 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - events_tx: &mpsc::UnboundedSender>, + event_stream: &AgentResponseEventStream, cx: &mut Context, ) { - events_tx - .unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone()))) 
- .ok(); + event_stream.send_thinking(&new_text); let last_message = self.last_assistant_message(); if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut() @@ -423,7 +421,7 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - events_tx: &mpsc::UnboundedSender>, + event_stream: &AgentResponseEventStream, cx: &mut Context, ) -> Option> { cx.notify(); @@ -446,32 +444,18 @@ impl Thread { } }); if push_new_tool_use { - events_tx - .unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall { - id: acp::ToolCallId(tool_use.id.to_string().into()), - title: tool_use.name.to_string(), - kind: acp::ToolKind::Other, - status: acp::ToolCallStatus::Pending, - content: vec![], - locations: vec![], - raw_input: Some(tool_use.input.clone()), - }))) - .ok(); + event_stream.send_tool_call(&tool_use); last_message .content .push(MessageContent::ToolUse(tool_use.clone())); } else { - events_tx - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use.id.to_string().into()), - fields: acp::ToolCallUpdateFields { - raw_input: Some(tool_use.input.clone()), - ..Default::default() - }, - }, - ))) - .ok(); + event_stream.send_tool_call_update( + &tool_use.id, + acp::ToolCallUpdateFields { + raw_input: Some(tool_use.input.clone()), + ..Default::default() + }, + ); } if !tool_use.is_input_complete { @@ -479,22 +463,10 @@ impl Thread { } if let Some(tool) = self.tools.get(tool_use.name.as_ref()) { - events_tx - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use.id.to_string().into()), - fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::InProgress), - ..Default::default() - }, - }, - ))) - .ok(); - - let pending_tool_result = tool.clone().run(tool_use.input, cx); - + let tool_result = + self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx); 
Some(cx.foreground_executor().spawn(async move { - match pending_tool_result.await { + match tool_result.await { Ok(tool_output) => LanguageModelToolResult { tool_use_id: tool_use.id, tool_name: tool_use.name, @@ -523,6 +495,30 @@ impl Thread { } } + fn run_tool( + &self, + tool: Arc, + tool_use: LanguageModelToolUse, + event_stream: AgentResponseEventStream, + cx: &mut Context, + ) -> Task> { + let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx); + cx.spawn(async move |_this, cx| { + if needs_authorization? { + event_stream.authorize_tool_call(&tool_use).await?; + } + + event_stream.send_tool_call_update( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }, + ); + cx.update(|cx| tool.run(tool_use.input, cx))?.await + }) + } + fn handle_tool_use_json_parse_error_event( &mut self, tool_use_id: LanguageModelToolUseId, @@ -575,7 +571,7 @@ impl Thread { log::debug!("Completion intent: {:?}", completion_intent); log::debug!("Completion mode: {:?}", self.completion_mode); - let messages = self.build_request_messages(cx); + let messages = self.build_request_messages(); log::info!("Request will include {} messages", messages.len()); let tools: Vec = self @@ -613,14 +609,13 @@ impl Thread { request } - fn build_request_messages(&self, cx: &App) -> Vec { + fn build_request_messages(&self) -> Vec { log::trace!( "Building request messages from {} thread messages", self.messages.len() ); - let messages = self - .build_system_message(cx) + let messages = Some(self.build_system_message()) .iter() .chain(self.messages.iter()) .map(|message| { @@ -674,6 +669,10 @@ where schemars::schema_for!(Self::Input) } + /// Returns true if the tool needs the user's authorization + /// before running. + fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool; + /// Runs the tool with the provided input.
fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; @@ -688,6 +687,7 @@ pub trait AnyAgentTool { fn name(&self) -> SharedString; fn description(&self, cx: &mut App) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result; fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; } @@ -707,6 +707,14 @@ where Ok(serde_json::to_value(self.0.input_schema(format))?) } + fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result { + let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); + match parsed_input { + Ok(input) => Ok(self.0.needs_authorization(input, cx)), + Err(error) => Err(anyhow!(error)), + } + } + fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task> { let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); match parsed_input { @@ -716,39 +724,153 @@ where } } -fn to_acp_stop_reason(reason: StopReason) -> Option { - match reason { - StopReason::EndTurn => Some(acp::StopReason::EndTurn), - StopReason::MaxTokens => Some(acp::StopReason::MaxTokens), - StopReason::Refusal => Some(acp::StopReason::Refusal), - StopReason::ToolUse => None, - } -} +#[derive(Clone)] +struct AgentResponseEventStream( + mpsc::UnboundedSender>, +); -fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate { - let status = if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }; - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string().into(), - LanguageModelToolResultContent::Image(LanguageModelImage { source, .. 
}) => { - acp::ToolCallContent::Content { - content: acp::ContentBlock::Image(acp::ImageContent { - annotations: None, - data: source.to_string(), - mime_type: ImageFormat::Png.mime_type().to_string(), - }), +impl AgentResponseEventStream { + fn send_text(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .ok(); + } + + fn send_thinking(&self, text: &str) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .ok(); + } + + fn authorize_tool_call( + &self, + tool_use: &LanguageModelToolUse, + ) -> impl use<> + Future> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + ToolCallAuthorization { + tool_call: acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(tool_use.input.clone()), + }, + options: vec![ + acp::PermissionOption { + id: acp::PermissionOptionId("always_allow".into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("allow".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("deny".into()), + name: "Deny".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + response: response_tx, + }, + ))) + .ok(); + async move { + match response_rx.await?.0.as_ref() { + "allow" | "always_allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), } } - }; - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()), - fields: acp::ToolCallUpdateFields { - status: Some(status), - content: Some(vec![content]), - ..Default::default() - }, + } + + fn send_tool_call(&self, tool_use: 
&LanguageModelToolUse) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(tool_use.input.clone()), + }))) + .ok(); + } + + fn send_tool_call_update( + &self, + tool_use_id: &LanguageModelToolUseId, + fields: acp::ToolCallUpdateFields, + ) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_use_id.to_string().into()), + fields, + }, + ))) + .ok(); + } + + fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) { + let status = if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }; + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string().into(), + LanguageModelToolResultContent::Image(LanguageModelImage { source, .. 
}) => { + acp::ToolCallContent::Content { + content: acp::ContentBlock::Image(acp::ImageContent { + annotations: None, + data: source.to_string(), + mime_type: ImageFormat::Png.mime_type().to_string(), + }), + } + } + }; + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + acp::ToolCallUpdate { + id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + status: Some(status), + content: Some(vec![content]), + ..Default::default() + }, + }, + ))) + .ok(); + } + + fn send_stop(&self, reason: StopReason) { + match reason { + StopReason::EndTurn => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .ok(); + } + StopReason::MaxTokens => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .ok(); + } + StopReason::Refusal => { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .ok(); + } + StopReason::ToolUse => {} + } + } + + fn send_error(&self, error: LanguageModelCompletionError) { + self.0.unbounded_send(Err(error)).ok(); } } diff --git a/crates/agent2/src/tools/glob.rs b/crates/agent2/src/tools/glob.rs index 9434311aaf..f44ce9f359 100644 --- a/crates/agent2/src/tools/glob.rs +++ b/crates/agent2/src/tools/glob.rs @@ -46,6 +46,10 @@ impl AgentTool for GlobTool { .into() } + fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { + false + } + fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> { let path_matcher = match PathMatcher::new([&input.glob]) { Ok(matcher) => matcher, diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index c0b64fcc41..e676b7ee46 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate { let response = cx .update(|cx| { self.thread.borrow().update(cx, |thread, cx| { - thread.request_tool_call_permission(tool_call, 
acp_options, cx) + thread.request_tool_call_authorization(tool_call, acp_options, cx) }) })? .context("Failed to update thread")? diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index 178796816a..ff71783b48 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -210,7 +210,7 @@ impl acp::Client for ClientDelegate { .context("Failed to get session")? .thread .update(cx, |thread, cx| { - thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx) + thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx) })?; let result = rx.await; diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index c6f8bb5b69..53a8556e74 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool { let chosen_option = thread .update(cx, |thread, cx| { - thread.request_tool_call_permission( + thread.request_tool_call_authorization( claude_tool.as_acp(tool_call_id), vec![ acp::PermissionOption { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index a9b39e6cea..06e47a11dc 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -3147,7 +3147,7 @@ mod tests { let task = cx.spawn(async move |cx| { if let Some((tool_call, options)) = permission_request { let permission = thread.update(cx, |thread, cx| { - thread.request_tool_call_permission( + thread.request_tool_call_authorization( tool_call.clone(), options.clone(), cx, diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 57fdc51336..90bb2e9b7c 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool; use 
crate::diagnostics_tool::DiagnosticsTool; use crate::edit_file_tool::EditFileTool; use crate::fetch_tool::FetchTool; -use crate::find_path_tool::FindPathTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::thinking_tool::ThinkingTool; pub use edit_file_tool::{EditFileMode, EditFileToolInput}; -pub use find_path_tool::FindPathToolInput; +pub use find_path_tool::*; pub use grep_tool::{GrepTool, GrepToolInput}; pub use open_tool::OpenTool; pub use project_notifications_tool::ProjectNotificationsTool; diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index d737ef9246..7eb63eec5e 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -18,7 +18,7 @@ use util::{ResultExt, get_system_shell}; use crate::UserPromptId; -#[derive(Debug, Clone, Serialize)] +#[derive(Default, Debug, Clone, Serialize)] pub struct ProjectContext { pub worktrees: Vec, /// Whether any worktree has a rules_file. Provided as a field because handlebars can't do this. @@ -71,14 +71,14 @@ pub struct UserRulesContext { pub contents: String, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct WorktreeContext { pub root_name: String, pub abs_path: Arc, pub rules_file: Option, } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize)] pub struct RulesFileContext { pub path_in_worktree: Arc, pub text: String,