diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 762664e1b1..31fd798bef 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -3,7 +3,7 @@ use crate::thread::{ ThreadEvent, ThreadFeedback, }; use crate::thread_store::ThreadStore; -use crate::tool_use::{ToolUse, ToolUseStatus}; +use crate::tool_use::{PendingToolUseStatus, ToolType, ToolUse, ToolUseStatus}; use crate::ui::ContextPill; use collections::HashMap; use editor::{Editor, MultiBuffer}; @@ -471,11 +471,18 @@ impl ActiveThread { for tool_use in tool_uses { self.render_tool_use_label_markdown( - tool_use.id, + tool_use.id.clone(), tool_use.ui_text.clone(), window, cx, ); + self.render_scripting_tool_use_markdown( + tool_use.id, + tool_use.name.as_ref(), + tool_use.input.clone(), + window, + cx, + ); } } ThreadEvent::ToolFinished { @@ -491,13 +498,6 @@ impl ActiveThread { window, cx, ); - self.render_scripting_tool_use_markdown( - tool_use.id.clone(), - tool_use.name.as_ref(), - tool_use.input.clone(), - window, - cx, - ); } if self.thread.read(cx).all_tools_finished() { @@ -996,29 +996,31 @@ impl ActiveThread { ) .child(div().p_2().child(message_content)), ), - Role::Assistant => v_flex() - .id(("message-container", ix)) - .ml_2() - .pl_2() - .border_l_1() - .border_color(cx.theme().colors().border_variant) - .child(message_content) - .when( - !tool_uses.is_empty() || !scripting_tool_uses.is_empty(), - |parent| { - parent.child( - v_flex() - .children( - tool_uses - .into_iter() - .map(|tool_use| self.render_tool_use(tool_use, cx)), - ) - .children(scripting_tool_uses.into_iter().map(|tool_use| { - self.render_scripting_tool_use(tool_use, window, cx) - })), - ) - }, - ), + Role::Assistant => { + v_flex() + .id(("message-container", ix)) + .ml_2() + .pl_2() + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .child(message_content) + .when( + !tool_uses.is_empty() || !scripting_tool_uses.is_empty(), + |parent| { + parent.child( + v_flex() + .children( + tool_uses + .into_iter() + .map(|tool_use| self.render_tool_use(tool_use, cx)), + ) + .children(scripting_tool_uses.into_iter().map(|tool_use| { + self.render_scripting_tool_use(tool_use, cx) + })), + ) + }, + ) + } Role::System => div().id(("message-container", ix)).py_1().px_2().child( v_flex() .bg(colors.editor_background) @@ -1379,7 +1381,8 @@ impl ActiveThread { ) .child({ let (icon_name, color, animated) = match &tool_use.status { - ToolUseStatus::Pending => { + ToolUseStatus::Pending + | ToolUseStatus::NeedsConfirmation => { (IconName::Warning, Color::Warning, false) } ToolUseStatus::Running => { @@ -1500,6 +1503,14 @@ impl ActiveThread { ), ), ToolUseStatus::Pending => container, + ToolUseStatus::NeedsConfirmation => container.child( + content_container().child( + Label::new("Asking Permission") + .size(LabelSize::Small) + .color(Color::Muted) + .buffer_font(cx), + ), + ), }), ) }), @@ -1509,7 +1520,6 @@ impl ActiveThread { fn render_scripting_tool_use( &self, tool_use: ToolUse, - window: &Window, cx: &mut Context, ) -> impl IntoElement { let is_open = self @@ -1555,13 +1565,25 @@ impl ActiveThread { } }), )) - .child(div().text_ui_sm(cx).child(render_markdown( - tool_use.ui_text.clone(), - self.language_registry.clone(), - window, - cx, - ))) - .truncate(), + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Terminal) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + div() + .text_ui_sm(cx) + .children( + self.rendered_tool_use_labels + .get(&tool_use.id) + .cloned(), + ) + .truncate(), + ), + ), ) .child( Label::new(match tool_use.status { @@ -1569,6 +1591,7 @@ impl ActiveThread { ToolUseStatus::Running => "Running", ToolUseStatus::Finished(_) => "Finished", ToolUseStatus::Error(_) => "Error", + ToolUseStatus::NeedsConfirmation => "Asking Permission", }) .size(LabelSize::XSmall) .buffer_font(cx), @@ -1620,6 +1643,13 @@ impl ActiveThread { .child(Label::new(err)), ), ToolUseStatus::Pending | ToolUseStatus::Running => parent, + ToolUseStatus::NeedsConfirmation => parent.child( + v_flex() + .gap_0p5() + .py_1() + .px_2p5() + .child(Label::new("Asking Permission")), + ), }), ) }), @@ -1682,6 +1712,45 @@ impl ActiveThread { .into_any() } + fn handle_allow_tool( + &mut self, + tool_use_id: LanguageModelToolUseId, + _: &ClickEvent, + _window: &mut Window, + cx: &mut Context, + ) { + if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self + .thread + .read(cx) + .pending_tool(&tool_use_id) + .map(|tool_use| tool_use.status.clone()) + { + self.thread.update(cx, |thread, cx| { + thread.run_tool( + c.tool_use_id.clone(), + c.ui_text.clone(), + c.input.clone(), + &c.messages, + c.tool_type.clone(), + cx, + ); + }); + } + } + + fn handle_deny_tool( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_type: ToolType, + _: &ClickEvent, + _window: &mut Window, + cx: &mut Context, + ) { + self.thread.update(cx, |thread, cx| { + thread.deny_tool_use(tool_use_id, tool_type, cx); + }); + } + fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context) { let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref() else { @@ -1704,12 +1773,82 @@ impl ActiveThread { task.detach(); } } + + fn render_confirmations<'a>( + &'a mut self, + cx: &'a mut Context, + ) -> impl Iterator + 'a { + let thread = self.thread.read(cx); + + thread + .tools_needing_confirmation() + .map(|(tool_type, tool)| { + div() + .m_3() + .p_2() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(cx.theme().colors().border) + .rounded_lg() + .child( + v_flex() + .gap_1() + .child( + v_flex() + .gap_0p5() + .child( + Label::new("The agent wants to run this action:") + .color(Color::Muted), + ) + .child(div().p_3().child(Label::new(&tool.ui_text))), + ) + .child( + h_flex() + .gap_1() + .child({ + let tool_id = tool.id.clone(); + Button::new("allow-tool-action", "Allow").on_click( + cx.listener(move |this, event, window, cx| { + this.handle_allow_tool( + tool_id.clone(), + event, + window, + cx, + ) + }), + ) + }) + .child({ + let tool_id = tool.id.clone(); + Button::new("deny-tool", "Deny").on_click(cx.listener( + move |this, event, window, cx| { + this.handle_deny_tool( + tool_id.clone(), + tool_type.clone(), + event, + window, + cx, + ) + }, + )) + }), + ) + .child( + Label::new("Note: A future release will introduce a way to remember your answers to these. In the meantime, you can avoid these prompts by adding \"assistant\": { \"always_allow_tool_actions\": true } to your settings.json.") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + .into_any() + }) + } } impl Render for ActiveThread { - fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() .size_full() .child(list(self.list_state.clone()).flex_grow()) + .children(self.render_confirmations(cx)) } } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 41834a6e5b..e090ac7fe6 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -3,14 +3,15 @@ use std::io::Write; use std::sync::Arc; use anyhow::{Context as _, Result}; -use assistant_tool::{ActionLog, ToolWorkingSet}; +use assistant_settings::AssistantSettings; +use assistant_tool::{ActionLog, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; use fs::Fs; use futures::future::Shared; use futures::{FutureExt, StreamExt as _}; use git; -use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, @@ -24,6 +25,7 @@ use prompt_store::{ }; use scripting_tool::{ScriptingSession, ScriptingTool}; use serde::{Deserialize, Serialize}; +use settings::Settings; use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _}; use uuid::Uuid; @@ -32,7 +34,7 @@ use crate::thread_store::{ SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedToolUse, }; -use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; +use crate::tool_use::{PendingToolUse, PendingToolUseStatus, ToolType, ToolUse, ToolUseState}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -350,6 +352,44 @@ impl Thread { &self.tools } + pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> { + self.tool_use + .pending_tool_uses() + .into_iter() + .find(|tool_use| &tool_use.id == id) + .or_else(|| { + self.scripting_tool_use + .pending_tool_uses() + .into_iter() + .find(|tool_use| &tool_use.id == id) + }) + } + + pub fn tools_needing_confirmation(&self) -> impl Iterator { + self.tool_use + .pending_tool_uses() + .into_iter() + .filter_map(|tool_use| { + if let PendingToolUseStatus::NeedsConfirmation(confirmation) = &tool_use.status { + Some((confirmation.tool_type.clone(), tool_use)) + } else { + None + } + }) + .chain( + self.scripting_tool_use + .pending_tool_uses() + .into_iter() + .filter_map(|tool_use| { + if tool_use.status.needs_confirmation() { + Some((ToolType::ScriptingTool, tool_use)) + } else { + None + } + }), + ) + } + pub fn checkpoint_for_message(&self, id: MessageId) -> Option { self.checkpoints_by_message.get(&id).cloned() } @@ -1178,6 +1218,7 @@ impl Thread { cx: &mut Context, ) -> impl IntoIterator { let request = self.to_completion_request(RequestKind::Chat, cx); + let messages = Arc::new(request.messages); let pending_tool_uses = self .tool_use .pending_tool_uses() @@ -1188,18 +1229,33 @@ impl Thread { for tool_use in pending_tool_uses.iter() { if let Some(tool) = self.tools.tool(&tool_use.name, cx) { - let task = tool.run( - tool_use.input.clone(), - &request.messages, - self.project.clone(), - self.action_log.clone(), - cx, - ); - - self.insert_tool_output( + if tool.needs_confirmation() + && !AssistantSettings::get_global(cx).always_allow_tool_actions + { + self.tool_use.confirm_tool_use( + tool_use.id.clone(), + tool_use.ui_text.clone(), + tool_use.input.clone(), + messages.clone(), + ToolType::NonScriptingTool(tool), + ); + } else { + self.run_tool( + tool_use.id.clone(), + tool_use.ui_text.clone(), + tool_use.input.clone(), + &messages, + ToolType::NonScriptingTool(tool), + cx, + ); + } + } else if let Some(tool) = self.tools.tool(&tool_use.name, cx) { + self.run_tool( tool_use.id.clone(), - tool_use.ui_text.clone().into(), - task, + tool_use.ui_text.clone(), + tool_use.input.clone(), + &messages, + ToolType::NonScriptingTool(tool), cx, ); } @@ -1214,36 +1270,13 @@ impl Thread { .collect::>(); for scripting_tool_use in pending_scripting_tool_uses.iter() { - let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) { - Err(err) => Task::ready(Err(err.into())), - Ok(input) => { - let (script_id, script_task) = - self.scripting_session.update(cx, move |session, cx| { - session.run_script(input.lua_script, cx) - }); - - let session = self.scripting_session.clone(); - cx.spawn(async move |_, cx| { - script_task.await; - - let message = session.read_with(cx, |session, _cx| { - // Using a id to get the script output seems impractical. - // Why not just include it in the Task result? - // This is because we'll later report the script state as it runs, - session - .get(script_id) - .output_message_for_llm() - .expect("Script shouldn't still be running") - })?; - - Ok(message) - }) - } - }; - - let ui_text: SharedString = scripting_tool_use.name.clone().into(); - - self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx); + self.scripting_tool_use.confirm_tool_use( + scripting_tool_use.id.clone(), + scripting_tool_use.ui_text.clone(), + scripting_tool_use.input.clone(), + messages.clone(), + ToolType::ScriptingTool, + ); } pending_tool_uses @@ -1251,17 +1284,49 @@ impl Thread { .chain(pending_scripting_tool_uses) } - pub fn insert_tool_output( + pub fn run_tool( &mut self, tool_use_id: LanguageModelToolUseId, - ui_text: SharedString, - output: Task>, - cx: &mut Context, + ui_text: impl Into, + input: serde_json::Value, + messages: &[LanguageModelRequestMessage], + tool_type: ToolType, + cx: &mut Context<'_, Thread>, ) { - let insert_output_task = cx.spawn({ - let tool_use_id = tool_use_id.clone(); - async move |thread, cx| { - let output = output.await; + match tool_type { + ToolType::ScriptingTool => { + let task = self.spawn_scripting_tool_use(tool_use_id.clone(), input, cx); + self.scripting_tool_use + .run_pending_tool(tool_use_id, ui_text.into(), task); + } + ToolType::NonScriptingTool(tool) => { + let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx); + self.tool_use + .run_pending_tool(tool_use_id, ui_text.into(), task); + } + } + } + + fn spawn_tool_use( + &mut self, + tool_use_id: LanguageModelToolUseId, + messages: &[LanguageModelRequestMessage], + input: serde_json::Value, + tool: Arc, + cx: &mut Context, + ) -> Task<()> { + let run_tool = tool.run( + input, + messages, + self.project.clone(), + self.action_log.clone(), + cx, + ); + + cx.spawn({ + async move |thread: WeakEntity, cx| { + let output = run_tool.await; + thread .update(cx, |thread, cx| { let pending_tool_use = thread @@ -1276,23 +1341,46 @@ impl Thread { }) .ok(); } - }); - - self.tool_use - .run_pending_tool(tool_use_id, ui_text, insert_output_task); + }) } - pub fn insert_scripting_tool_output( + fn spawn_scripting_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, - ui_text: SharedString, - output: Task>, - cx: &mut Context, - ) { - let insert_output_task = cx.spawn({ + input: serde_json::Value, + cx: &mut Context, + ) -> Task<()> { + let task = match ScriptingTool::deserialize_input(input) { + Err(err) => Task::ready(Err(err.into())), + Ok(input) => { + let (script_id, script_task) = + self.scripting_session.update(cx, move |session, cx| { + session.run_script(input.lua_script, cx) + }); + + let session = self.scripting_session.clone(); + cx.spawn(async move |_, cx| { + script_task.await; + + let message = session.read_with(cx, |session, _cx| { + // Using a id to get the script output seems impractical. + // Why not just include it in the Task result? + // This is because we'll later report the script state as it runs, + session + .get(script_id) + .output_message_for_llm() + .expect("Script shouldn't still be running") + })?; + + Ok(message) + }) + } + }; + + cx.spawn({ let tool_use_id = tool_use_id.clone(); async move |thread, cx| { - let output = output.await; + let output = task.await; thread .update(cx, |thread, cx| { let pending_tool_use = thread @@ -1307,10 +1395,7 @@ impl Thread { }) .ok(); } - }); - - self.scripting_tool_use - .run_pending_tool(tool_use_id, ui_text, insert_output_task); + }) } pub fn attach_tool_results( @@ -1568,6 +1653,30 @@ impl Thread { pub fn cumulative_token_usage(&self) -> TokenUsage { self.cumulative_token_usage.clone() } + + pub fn deny_tool_use( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_type: ToolType, + cx: &mut Context, + ) { + let err = Err(anyhow::anyhow!( + "Permission to run tool action denied by user" + )); + + if let ToolType::ScriptingTool = tool_type { + self.scripting_tool_use + .insert_tool_output(tool_use_id.clone(), err); + } else { + self.tool_use.insert_tool_output(tool_use_id.clone(), err); + } + + cx.emit(ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use: None, + canceled: true, + }); + } } #[derive(Debug, Clone)] diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index d020eab72b..7b7399c950 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use assistant_tool::ToolWorkingSet; +use assistant_tool::{Tool, ToolWorkingSet}; use collections::HashMap; use futures::future::Shared; use futures::FutureExt as _; @@ -10,6 +10,7 @@ use language_model::{ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, }; +use scripting_tool::ScriptingTool; use crate::thread::MessageId; use crate::thread_store::SerializedMessage; @@ -25,6 +26,7 @@ pub struct ToolUse { #[derive(Debug, Clone)] pub enum ToolUseStatus { + NeedsConfirmation, Pending, Running, Finished(SharedString), @@ -163,16 +165,19 @@ impl ToolUseState { } if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { - return match pending_tool_use.status { + match pending_tool_use.status { PendingToolUseStatus::Idle => ToolUseStatus::Pending, + PendingToolUseStatus::NeedsConfirmation { .. } => { + ToolUseStatus::NeedsConfirmation + } PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, PendingToolUseStatus::Error(ref err) => { ToolUseStatus::Error(err.clone().into()) } - }; + } + } else { + ToolUseStatus::Pending } - - ToolUseStatus::Pending })(); tool_uses.push(ToolUse { @@ -195,6 +200,8 @@ impl ToolUseState { ) -> SharedString { if let Some(tool) = self.tools.tool(tool_name, cx) { tool.ui_text(input).into() + } else if tool_name == ScriptingTool::NAME { + "Run Lua Script".into() } else { "Unknown tool".into() } @@ -272,6 +279,28 @@ impl ToolUseState { } } + pub fn confirm_tool_use( + &mut self, + tool_use_id: LanguageModelToolUseId, + ui_text: impl Into>, + input: serde_json::Value, + messages: Arc>, + tool_type: ToolType, + ) { + if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { + let ui_text = ui_text.into(); + tool_use.ui_text = ui_text.clone(); + let confirmation = Confirmation { + tool_use_id, + input, + messages, + tool_type, + ui_text, + }; + tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation)); + } + } + pub fn insert_tool_output( &mut self, tool_use_id: LanguageModelToolUseId, @@ -369,9 +398,25 @@ pub struct PendingToolUse { pub status: PendingToolUseStatus, } +#[derive(Debug, Clone)] +pub enum ToolType { + ScriptingTool, + NonScriptingTool(Arc), +} + +#[derive(Debug, Clone)] +pub struct Confirmation { + pub tool_use_id: LanguageModelToolUseId, + pub input: serde_json::Value, + pub ui_text: Arc, + pub messages: Arc>, + pub tool_type: ToolType, +} + #[derive(Debug, Clone)] pub enum PendingToolUseStatus { Idle, + NeedsConfirmation(Arc), Running { _task: Shared> }, Error(#[allow(unused)] Arc), } @@ -384,4 +429,8 @@ impl PendingToolUseStatus { pub fn is_error(&self) -> bool { matches!(self, PendingToolUseStatus::Error(_)) } + + pub fn needs_confirmation(&self) -> bool { + matches!(self, PendingToolUseStatus::NeedsConfirmation { .. }) + } } diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 48607f888c..1e37ef91c0 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -72,6 +72,7 @@ pub struct AssistantSettings { pub using_outdated_settings_version: bool, pub enable_experimental_live_diffs: bool, pub profiles: IndexMap, AgentProfile>, + pub always_allow_tool_actions: bool, } impl AssistantSettings { @@ -173,6 +174,7 @@ impl AssistantSettingsContent { inline_alternatives: None, enable_experimental_live_diffs: None, profiles: None, + always_allow_tool_actions: None, }, VersionedAssistantSettingsContent::V2(settings) => settings.clone(), }, @@ -195,6 +197,7 @@ impl AssistantSettingsContent { inline_alternatives: None, enable_experimental_live_diffs: None, profiles: None, + always_allow_tool_actions: None, }, } } @@ -325,6 +328,7 @@ impl Default for VersionedAssistantSettingsContent { inline_alternatives: None, enable_experimental_live_diffs: None, profiles: None, + always_allow_tool_actions: None, }) } } @@ -363,6 +367,11 @@ pub struct AssistantSettingsContentV2 { enable_experimental_live_diffs: Option, #[schemars(skip)] profiles: Option, AgentProfileContent>>, + /// Whenever a tool action would normally wait for your confirmation + /// that you allow it, always choose to allow it. + /// + /// Default: false + always_allow_tool_actions: Option, } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] @@ -499,6 +508,10 @@ impl Settings for AssistantSettings { &mut settings.enable_experimental_live_diffs, value.enable_experimental_live_diffs, ); + merge( + &mut settings.always_allow_tool_actions, + value.always_allow_tool_actions, + ); if let Some(profiles) = value.profiles { settings @@ -579,6 +592,7 @@ mod tests { default_height: None, enable_experimental_live_diffs: None, profiles: None, + always_allow_tool_actions: None, }), ) }, diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 4db31d79d4..20fdb05439 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -1,14 +1,14 @@ mod tool_registry; mod tool_working_set; -use std::sync::Arc; - use anyhow::Result; use collections::{HashMap, HashSet}; use gpui::{App, Context, Entity, SharedString, Task}; use language::Buffer; use language_model::LanguageModelRequestMessage; use project::Project; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; pub use crate::tool_registry::*; pub use crate::tool_working_set::*; @@ -38,6 +38,10 @@ pub trait Tool: 'static + Send + Sync { ToolSource::Native } + /// Returns true iff the tool needs the users's confirmation + /// before having permission to run. + fn needs_confirmation(&self) -> bool; + /// Returns the JSON schema that describes the tool's input. fn input_schema(&self) -> serde_json::Value { serde_json::Value::Object(serde_json::Map::default()) @@ -57,6 +61,12 @@ pub trait Tool: 'static + Send + Sync { ) -> Task>; } +impl Debug for dyn Tool { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Tool").field("name", &self.name()).finish() + } +} + /// Tracks actions performed by tools in a thread #[derive(Debug)] pub struct ActionLog { diff --git a/crates/assistant_tools/src/bash_tool.rs b/crates/assistant_tools/src/bash_tool.rs index 4648da2881..51c8a024b3 100644 --- a/crates/assistant_tools/src/bash_tool.rs +++ b/crates/assistant_tools/src/bash_tool.rs @@ -23,6 +23,10 @@ impl Tool for BashTool { "bash".to_string() } + fn needs_confirmation(&self) -> bool { + true + } + fn description(&self) -> String { include_str!("./bash_tool/description.md").to_string() } diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index 12b1a6d33f..cc13d34e80 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -30,6 +30,10 @@ impl Tool for DeletePathTool { "delete-path".into() } + fn needs_confirmation(&self) -> bool { + true + } + fn description(&self) -> String { include_str!("./delete_path_tool/description.md").into() } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index c7036ff9cb..95aec472a5 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -37,6 +37,10 @@ impl Tool for DiagnosticsTool { "diagnostics".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./diagnostics_tool/description.md").into() } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 4aff41f289..dad870851f 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -79,6 +79,10 @@ impl Tool for EditFilesTool { "edit-files".into() } + fn needs_confirmation(&self) -> bool { + true + } + fn description(&self) -> String { include_str!("./edit_files_tool/description.md").into() } @@ -145,30 +149,22 @@ impl Tool for EditFilesTool { struct EditToolRequest { parser: EditActionParser, - editor_response: EditorResponse, + output: String, + changed_buffers: HashSet>, + bad_searches: Vec, project: Entity, action_log: Entity, tool_log: Option<(Entity, EditToolRequestId)>, } -enum EditorResponse { - /// The editor model hasn't produced any actions yet. - /// If we don't have any by the end, we'll return its message to the architect model. - Message(String), - /// The editor model produced at least one action. - Actions { - applied: Vec, - search_errors: Vec, - }, -} - -struct AppliedAction { - source: String, - buffer: Entity, +#[derive(Debug)] +enum DiffResult { + BadSearch(BadSearch), + Diff(language::Diff), } #[derive(Debug)] -enum SearchError { +enum BadSearch { NoMatch { file_path: String, search: String, @@ -234,7 +230,10 @@ impl EditToolRequest { let mut request = Self { parser: EditActionParser::new(), - editor_response: EditorResponse::Message(String::with_capacity(256)), + // we start with the success header so we don't need to shift the output in the common case + output: Self::SUCCESS_OUTPUT_HEADER.to_string(), + changed_buffers: HashSet::default(), + bad_searches: Vec::new(), action_log, project, tool_log, @@ -251,12 +250,6 @@ impl EditToolRequest { async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> { let new_actions = self.parser.parse_chunk(chunk); - if let EditorResponse::Message(ref mut message) = self.editor_response { - if new_actions.is_empty() { - message.push_str(chunk); - } - } - if let Some((ref log, req_id)) = self.tool_log { log.update(cx, |log, cx| { log.push_editor_response_chunk(req_id, chunk, &new_actions, cx) @@ -287,11 +280,6 @@ impl EditToolRequest { .update(cx, |project, cx| project.open_buffer(project_path, cx))? .await?; - enum DiffResult { - Diff(language::Diff), - SearchError(SearchError), - } - let result = match action { EditAction::Replace { old, @@ -301,39 +289,7 @@ impl EditToolRequest { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; cx.background_executor() - .spawn(async move { - if snapshot.is_empty() { - let exists = snapshot - .file() - .map_or(false, |file| file.disk_state().exists()); - - let error = SearchError::EmptyBuffer { - file_path: file_path.display().to_string(), - exists, - search: old, - }; - - return anyhow::Ok(DiffResult::SearchError(error)); - } - - let replace_result = - // Try to match exactly - replace_exact(&old, &new, &snapshot) - .await - // If that fails, try being flexible about indentation - .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot)); - - let Some(diff) = replace_result else { - let error = SearchError::NoMatch { - search: old, - file_path: file_path.display().to_string(), - }; - - return Ok(DiffResult::SearchError(error)); - }; - - Ok(DiffResult::Diff(diff)) - }) + .spawn(Self::replace_diff(old, new, file_path, snapshot)) .await } EditAction::Write { content, .. } => Ok(DiffResult::Diff( @@ -344,179 +300,177 @@ impl EditToolRequest { }?; match result { - DiffResult::SearchError(error) => { - self.push_search_error(error); + DiffResult::BadSearch(invalid_replace) => { + self.bad_searches.push(invalid_replace); } DiffResult::Diff(diff) => { let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; - self.push_applied_action(AppliedAction { source, buffer }); + write!(&mut self.output, "\n\n{}", source)?; + self.changed_buffers.insert(buffer); } } - anyhow::Ok(()) + Ok(()) } - fn push_search_error(&mut self, error: SearchError) { - match &mut self.editor_response { - EditorResponse::Message(_) => { - self.editor_response = EditorResponse::Actions { - applied: Vec::new(), - search_errors: vec![error], - }; - } - EditorResponse::Actions { search_errors, .. } => { - search_errors.push(error); - } + async fn replace_diff( + old: String, + new: String, + file_path: std::path::PathBuf, + snapshot: language::BufferSnapshot, + ) -> Result { + if snapshot.is_empty() { + let exists = snapshot + .file() + .map_or(false, |file| file.disk_state().exists()); + + return Ok(DiffResult::BadSearch(BadSearch::EmptyBuffer { + file_path: file_path.display().to_string(), + exists, + search: old, + })); } + + let result = + // Try to match exactly + replace_exact(&old, &new, &snapshot) + .await + // If that fails, try being flexible about indentation + .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot)); + + let Some(diff) = result else { + return anyhow::Ok(DiffResult::BadSearch(BadSearch::NoMatch { + search: old, + file_path: file_path.display().to_string(), + })); + }; + + anyhow::Ok(DiffResult::Diff(diff)) } - fn push_applied_action(&mut self, action: AppliedAction) { - match &mut self.editor_response { - EditorResponse::Message(_) => { - self.editor_response = EditorResponse::Actions { - applied: vec![action], - search_errors: Vec::new(), - }; - } - EditorResponse::Actions { applied, .. } => { - applied.push(action); - } - } - } + const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:"; + const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!"; + const ERROR_OUTPUT_HEADER_WITH_EDITS: &str = + "Errors occurred. First, here's a list of the edits we managed to apply:"; async fn finalize(self, cx: &mut AsyncApp) -> Result { - match self.editor_response { - EditorResponse::Message(message) => Err(anyhow!( - "No edits were applied! You might need to provide more context.\n\n{}", - message - )), - EditorResponse::Actions { - applied, - search_errors, - } => { - let mut output = String::with_capacity(1024); + let changed_buffer_count = self.changed_buffers.len(); - let parse_errors = self.parser.errors(); - let has_errors = !search_errors.is_empty() || !parse_errors.is_empty(); + // Save each buffer once at the end + for buffer in &self.changed_buffers { + self.project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; + } - if has_errors { - let error_count = search_errors.len() + parse_errors.len(); + self.action_log + .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx)) + .log_err(); - if applied.is_empty() { - writeln!( - &mut output, - "{} errors occurred! No edits were applied.", - error_count, - )?; - } else { - writeln!( - &mut output, - "{} errors occurred, but {} edits were correctly applied.", - error_count, - applied.len(), - )?; + let errors = self.parser.errors(); - writeln!( - &mut output, - "# {} SEARCH/REPLACE block(s) applied:\n\nDo not re-send these since they are already applied!\n", - applied.len() - )?; - } - } else { - write!( - &mut output, - "Successfully applied! Here's a list of applied edits:" - )?; - } + if errors.is_empty() && self.bad_searches.is_empty() { + if changed_buffer_count == 0 { + return Err(anyhow!( + "The instructions didn't lead to any changes. You might need to consult the file contents first." + )); + } - let mut changed_buffers = HashSet::default(); + Ok(self.output) + } else { + let mut output = self.output; - for action in applied { - changed_buffers.insert(action.buffer); - write!(&mut output, "\n\n{}", action.source)?; - } + if output.is_empty() { + output.replace_range( + 0..Self::SUCCESS_OUTPUT_HEADER.len(), + Self::ERROR_OUTPUT_HEADER_NO_EDITS, + ); + } else { + output.replace_range( + 0..Self::SUCCESS_OUTPUT_HEADER.len(), + Self::ERROR_OUTPUT_HEADER_WITH_EDITS, + ); + } - for buffer in &changed_buffers { - self.project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; - } + if !self.bad_searches.is_empty() { + writeln!( + &mut output, + "\n\n# {} SEARCH/REPLACE block(s) failed to match:\n", + self.bad_searches.len() + )?; - self.action_log - .update(cx, |log, cx| log.buffer_edited(changed_buffers.clone(), cx)) - .log_err(); - - if !search_errors.is_empty() { - writeln!( - &mut output, - "\n\n## {} SEARCH/REPLACE block(s) failed to match:\n", - search_errors.len() - )?; - - for error in search_errors { - match error { - SearchError::NoMatch { file_path, search } => { - writeln!( - &mut output, - "### No exact match in: `{}`\n```\n{}\n```\n", - file_path, search, - )?; - } - SearchError::EmptyBuffer { - file_path, - exists: true, - search, - } => { - writeln!( - &mut output, - "### No match because `{}` is empty:\n```\n{}\n```\n", - file_path, search, - )?; - } - SearchError::EmptyBuffer { - file_path, - exists: false, - search, - } => { - writeln!( - &mut output, - "### No match because `{}` does not exist:\n```\n{}\n```\n", - file_path, search, - )?; - } + for bad_search in self.bad_searches { + match bad_search { + BadSearch::NoMatch { file_path, search } => { + writeln!( + &mut output, + "## No exact match in: `{}`\n```\n{}\n```\n", + file_path, search, + )?; + } + BadSearch::EmptyBuffer { + file_path, + exists: true, + search, + } => { + writeln!( + &mut output, + "## No match because `{}` is empty:\n```\n{}\n```\n", + file_path, search, + )?; + } + BadSearch::EmptyBuffer { + file_path, + exists: false, + search, + } => { + writeln!( + &mut output, + "## No match because `{}` does not exist:\n```\n{}\n```\n", + file_path, search, + )?; } } - - write!(&mut output, - "The SEARCH section must exactly match an existing block of lines including all white \ - space, comments, indentation, docstrings, etc." - )?; } - if !parse_errors.is_empty() { - writeln!( - &mut output, - "\n\n## {} SEARCH/REPLACE blocks failed to parse:", - parse_errors.len() - )?; + write!(&mut output, + "The SEARCH section must exactly match an existing block of lines including all white \ + space, comments, indentation, docstrings, etc." + )?; + } - for error in parse_errors { - writeln!(&mut output, "- {}", error)?; - } - } + if !errors.is_empty() { + writeln!( + &mut output, + "\n\n# {} SEARCH/REPLACE blocks failed to parse:", + errors.len() + )?; - if has_errors { - writeln!(&mut output, - "\n\nYou can fix errors by running the tool again. You can include instructions, \ - but errors are part of the conversation so you don't need to repeat them.", - )?; - - Err(anyhow!(output)) - } else { - Ok(output) + for error in errors { + writeln!(&mut output, "- {}", error)?; } } + + if changed_buffer_count > 0 { + writeln!( + &mut output, + "\n\nThe other SEARCH/REPLACE blocks were applied successfully. Do not re-send them!", + )?; + } + + writeln!( + &mut output, + "{}You can fix errors by running the tool again. You can include instructions, \ + but errors are part of the conversation so you don't need to repeat them.", + if changed_buffer_count == 0 { + "\n\n" + } else { + "" + } + )?; + + Err(anyhow!(output)) } } } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 1626d76ed2..9d9d9c75ff 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -113,6 +113,10 @@ impl Tool for FetchTool { "fetch".to_string() } + fn needs_confirmation(&self) -> bool { + true + } + fn description(&self) -> String { include_str!("./fetch_tool/description.md").to_string() } diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index d7501ddce7..813a65d450 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -31,7 +31,7 @@ pub struct ListDirectoryToolInput { /// /// If you wanna list contents in the directory `foo/baz`, you should use the path `foo/baz`. /// - pub path: Arc, + pub path: String, } pub struct ListDirectoryTool; @@ -41,6 +41,10 @@ impl Tool for ListDirectoryTool { "list-directory".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./list_directory_tool/description.md").into() } @@ -52,7 +56,7 @@ impl Tool for ListDirectoryTool { fn ui_text(&self, input: &serde_json::Value) -> String { match serde_json::from_value::(input.clone()) { - Ok(input) => format!("List the `{}` directory's contents", input.path.display()), + Ok(input) => format!("List the `{}` directory's contents", input.path), Err(_) => "List directory".to_string(), } } @@ -70,11 +74,29 @@ impl Tool for ListDirectoryTool { Err(err) => return Task::ready(Err(anyhow!(err))), }; + // Sometimes models will return these even though we tell it to give a path and not a glob. + // When this happens, just list the root worktree directories. + if matches!(input.path.as_str(), "." | "" | "./" | "*") { + let output = project + .read(cx) + .worktrees(cx) + .filter_map(|worktree| { + worktree.read(cx).root_entry().and_then(|entry| { + if entry.is_dir() { + entry.path.to_str() + } else { + None + } + }) + }) + .collect::>() + .join("\n"); + + return Task::ready(Ok(output)); + } + let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!( - "Path {} not found in project", - input.path.display() - ))); + return Task::ready(Err(anyhow!("Path {} not found in project", input.path))); }; let Some(worktree) = project .read(cx) @@ -85,11 +107,11 @@ impl Tool for ListDirectoryTool { let worktree = worktree.read(cx); let Some(entry) = worktree.entry_for_path(&project_path.path) else { - return Task::ready(Err(anyhow!("Path not found: {}", input.path.display()))); + return Task::ready(Err(anyhow!("Path not found: {}", input.path))); }; if !entry.is_dir() { - return Task::ready(Err(anyhow!("{} is not a directory.", input.path.display()))); + return Task::ready(Err(anyhow!("{} is not a directory.", input.path))); } let mut output = String::new(); @@ -102,7 +124,7 @@ impl Tool for ListDirectoryTool { .unwrap(); } if output.is_empty() { - return Task::ready(Ok(format!("{} is empty.", input.path.display()))); + return Task::ready(Ok(format!("{} is empty.", input.path))); } Task::ready(Ok(output)) } diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index dce45fe69b..04219b2bab 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -31,6 +31,10 @@ impl Tool for NowTool { "now".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into() } diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index 4507bbadbc..e2621e3f96 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -39,6 +39,10 @@ impl Tool for PathSearchTool { "path-search".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./path_search_tool/description.md").into() } diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index c25ab6606a..550d4e64e4 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -44,6 +44,10 @@ impl Tool for ReadFileTool { "read-file".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./read_file_tool/description.md").into() } diff --git a/crates/assistant_tools/src/regex_search_tool.rs b/crates/assistant_tools/src/regex_search_tool.rs index 20630fddbc..57f91245c5 100644 --- a/crates/assistant_tools/src/regex_search_tool.rs +++ b/crates/assistant_tools/src/regex_search_tool.rs @@ -41,6 +41,10 @@ impl Tool for RegexSearchTool { "regex-search".into() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./regex_search_tool/description.md").into() } diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index d645c0fb5d..3d020c001a 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -22,6 +22,10 @@ impl Tool for ThinkingTool { "thinking".to_string() } + fn needs_confirmation(&self) -> bool { + false + } + fn description(&self) -> String { include_str!("./thinking_tool/description.md").to_string() } diff --git a/crates/context_server/src/context_server_tool.rs b/crates/context_server/src/context_server_tool.rs index f0bb36537f..9cc3be1f8a 100644 --- a/crates/context_server/src/context_server_tool.rs +++ b/crates/context_server/src/context_server_tool.rs @@ -44,6 +44,10 @@ impl Tool for ContextServerTool { } } + fn needs_confirmation(&self) -> bool { + true + } + fn input_schema(&self) -> serde_json::Value { match &self.tool.input_schema { serde_json::Value::Null => {