Prompt before running some tools (#27284)

Also includes some fixes for how the Lua tool was being generated.

<img width="644" alt="Screenshot 2025-03-21 at 6 26 18 PM"
src="https://github.com/user-attachments/assets/51bd1685-5b3f-4ed3-b11e-6fa8017847d4"
/>


Release Notes:

- N/A

---------

Co-authored-by: Ben <ben@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Joseph T. Lyons <JosephTLyons@gmail.com>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>
This commit is contained in:
Richard Feldman 2025-03-22 00:05:34 -04:00 committed by GitHub
parent 90649fbc89
commit 4c86cda909
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 666 additions and 329 deletions

View file

@ -3,7 +3,7 @@ use crate::thread::{
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::tool_use::{PendingToolUseStatus, ToolType, ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
@ -471,11 +471,18 @@ impl ActiveThread {
for tool_use in tool_uses {
self.render_tool_use_label_markdown(
tool_use.id,
tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
self.render_scripting_tool_use_markdown(
tool_use.id,
tool_use.name.as_ref(),
tool_use.input.clone(),
window,
cx,
);
}
}
ThreadEvent::ToolFinished {
@ -491,13 +498,6 @@ impl ActiveThread {
window,
cx,
);
self.render_scripting_tool_use_markdown(
tool_use.id.clone(),
tool_use.name.as_ref(),
tool_use.input.clone(),
window,
cx,
);
}
if self.thread.read(cx).all_tools_finished() {
@ -996,29 +996,31 @@ impl ActiveThread {
)
.child(div().p_2().child(message_content)),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
.ml_2()
.pl_2()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.child(message_content)
.when(
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
|parent| {
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, window, cx)
})),
)
},
),
Role::Assistant => {
v_flex()
.id(("message-container", ix))
.ml_2()
.pl_2()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.child(message_content)
.when(
!tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
|parent| {
parent.child(
v_flex()
.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(scripting_tool_uses.into_iter().map(|tool_use| {
self.render_scripting_tool_use(tool_use, cx)
})),
)
},
)
}
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
@ -1379,7 +1381,8 @@ impl ActiveThread {
)
.child({
let (icon_name, color, animated) = match &tool_use.status {
ToolUseStatus::Pending => {
ToolUseStatus::Pending
| ToolUseStatus::NeedsConfirmation => {
(IconName::Warning, Color::Warning, false)
}
ToolUseStatus::Running => {
@ -1500,6 +1503,14 @@ impl ActiveThread {
),
),
ToolUseStatus::Pending => container,
ToolUseStatus::NeedsConfirmation => container.child(
content_container().child(
Label::new("Asking Permission")
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx),
),
),
}),
)
}),
@ -1509,7 +1520,6 @@ impl ActiveThread {
fn render_scripting_tool_use(
&self,
tool_use: ToolUse,
window: &Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let is_open = self
@ -1555,13 +1565,25 @@ impl ActiveThread {
}
}),
))
.child(div().text_ui_sm(cx).child(render_markdown(
tool_use.ui_text.clone(),
self.language_registry.clone(),
window,
cx,
)))
.truncate(),
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::Terminal)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
div()
.text_ui_sm(cx)
.children(
self.rendered_tool_use_labels
.get(&tool_use.id)
.cloned(),
)
.truncate(),
),
),
)
.child(
Label::new(match tool_use.status {
@ -1569,6 +1591,7 @@ impl ActiveThread {
ToolUseStatus::Running => "Running",
ToolUseStatus::Finished(_) => "Finished",
ToolUseStatus::Error(_) => "Error",
ToolUseStatus::NeedsConfirmation => "Asking Permission",
})
.size(LabelSize::XSmall)
.buffer_font(cx),
@ -1620,6 +1643,13 @@ impl ActiveThread {
.child(Label::new(err)),
),
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
ToolUseStatus::NeedsConfirmation => parent.child(
v_flex()
.gap_0p5()
.py_1()
.px_2p5()
.child(Label::new("Asking Permission")),
),
}),
)
}),
@ -1682,6 +1712,45 @@ impl ActiveThread {
.into_any()
}
/// Click handler for the "Allow" button of a tool confirmation.
///
/// Looks up the pending tool use with `tool_use_id` on the thread. If it is
/// currently waiting for confirmation, the tool is started with the data that
/// was captured when confirmation was requested. In every other case (unknown
/// id, or a status other than `NeedsConfirmation`) this does nothing.
fn handle_allow_tool(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    _: &ClickEvent,
    _window: &mut Window,
    cx: &mut Context<Self>,
) {
    // Clone the status so the immutable read of the thread ends before the
    // thread is updated below.
    let status = self
        .thread
        .read(cx)
        .pending_tool(&tool_use_id)
        .map(|pending| pending.status.clone());

    let Some(PendingToolUseStatus::NeedsConfirmation(confirmation)) = status else {
        return;
    };

    self.thread.update(cx, |thread, cx| {
        thread.run_tool(
            confirmation.tool_use_id.clone(),
            confirmation.ui_text.clone(),
            confirmation.input.clone(),
            &confirmation.messages,
            confirmation.tool_type.clone(),
            cx,
        );
    });
}
/// Click handler for the "Deny" button of a tool confirmation.
///
/// Forwards the denial to the thread via `Thread::deny_tool_use`, passing along
/// the id of the tool use and the kind of tool it belongs to.
fn handle_deny_tool(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    tool_type: ToolType,
    _: &ClickEvent,
    _window: &mut Window,
    cx: &mut Context<Self>,
) {
    let thread = &self.thread;
    thread.update(cx, move |thread, cx| {
        thread.deny_tool_use(tool_use_id, tool_type, cx)
    });
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
@ -1704,12 +1773,82 @@ impl ActiveThread {
task.detach();
}
}
/// Builds one confirmation card for every pending tool use that is waiting for
/// the user's permission.
///
/// Each card shows the tool use's `ui_text` together with an "Allow" and a
/// "Deny" button, wired to `handle_allow_tool` and `handle_deny_tool`
/// respectively, followed by a note about the `always_allow_tool_actions`
/// setting.
///
/// The button element ids include the card's position in the list so that
/// several confirmations shown at the same time do not share an element id.
fn render_confirmations<'a>(
    &'a mut self,
    cx: &'a mut Context<Self>,
) -> impl Iterator<Item = AnyElement> + 'a {
    let thread = self.thread.read(cx);

    thread
        .tools_needing_confirmation()
        .enumerate()
        .map(|(ix, (tool_type, tool))| {
            div()
                .m_3()
                .p_2()
                .bg(cx.theme().colors().editor_background)
                .border_1()
                .border_color(cx.theme().colors().border)
                .rounded_lg()
                .child(
                    v_flex()
                        .gap_1()
                        .child(
                            v_flex()
                                .gap_0p5()
                                .child(
                                    Label::new("The agent wants to run this action:")
                                        .color(Color::Muted),
                                )
                                .child(div().p_3().child(Label::new(&tool.ui_text))),
                        )
                        .child(
                            h_flex()
                                .gap_1()
                                .child({
                                    let tool_id = tool.id.clone();
                                    // Index the id: a static id would be shared by
                                    // every card rendered in the same frame.
                                    Button::new(("allow-tool-action", ix), "Allow").on_click(
                                        cx.listener(move |this, event, window, cx| {
                                            this.handle_allow_tool(
                                                tool_id.clone(),
                                                event,
                                                window,
                                                cx,
                                            )
                                        }),
                                    )
                                })
                                .child({
                                    let tool_id = tool.id.clone();
                                    Button::new(("deny-tool", ix), "Deny").on_click(cx.listener(
                                        move |this, event, window, cx| {
                                            this.handle_deny_tool(
                                                tool_id.clone(),
                                                tool_type.clone(),
                                                event,
                                                window,
                                                cx,
                                            )
                                        },
                                    ))
                                }),
                        )
                        .child(
                            Label::new("Note: A future release will introduce a way to remember your answers to these. In the meantime, you can avoid these prompts by adding \"assistant\": { \"always_allow_tool_actions\": true } to your settings.json.")
                                .color(Color::Muted)
                                .size(LabelSize::Small),
                        ),
                )
                .into_any()
        })
}
}
impl Render for ActiveThread {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.size_full()
.child(list(self.list_state.clone()).flex_grow())
.children(self.render_confirmations(cx))
}
}

View file

@ -3,14 +3,15 @@ use std::io::Write;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use assistant_tool::{ActionLog, ToolWorkingSet};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@ -24,6 +25,7 @@ use prompt_store::{
};
use scripting_tool::{ScriptingSession, ScriptingTool};
use serde::{Deserialize, Serialize};
use settings::Settings;
use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
use uuid::Uuid;
@ -32,7 +34,7 @@ use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
use crate::tool_use::{PendingToolUse, PendingToolUseStatus, ToolType, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@ -350,6 +352,44 @@ impl Thread {
&self.tools
}
/// Finds the pending tool use with the given id.
///
/// Regular tool uses are searched first; scripting tool uses are only
/// consulted when no regular tool use matches. Returns `None` when neither
/// collection contains the id.
pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
    let regular_match = self
        .tool_use
        .pending_tool_uses()
        .into_iter()
        .find(|pending| &pending.id == id);
    if regular_match.is_some() {
        return regular_match;
    }

    self.scripting_tool_use
        .pending_tool_uses()
        .into_iter()
        .find(|pending| &pending.id == id)
}
/// Iterates over every pending tool use that is waiting for user confirmation,
/// paired with the kind of tool it belongs to.
///
/// Regular tool uses come first and report the `ToolType` stored in their
/// confirmation; scripting tool uses follow and are always reported as
/// `ToolType::ScriptingTool`.
pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = (ToolType, &PendingToolUse)> {
    let regular = self
        .tool_use
        .pending_tool_uses()
        .into_iter()
        .filter_map(|pending| match &pending.status {
            PendingToolUseStatus::NeedsConfirmation(confirmation) => {
                Some((confirmation.tool_type.clone(), pending))
            }
            _ => None,
        });

    let scripting = self
        .scripting_tool_use
        .pending_tool_uses()
        .into_iter()
        .filter(|pending| pending.status.needs_confirmation())
        .map(|pending| (ToolType::ScriptingTool, pending));

    regular.chain(scripting)
}
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
self.checkpoints_by_message.get(&id).cloned()
}
@ -1178,6 +1218,7 @@ impl Thread {
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
.tool_use
.pending_tool_uses()
@ -1188,18 +1229,33 @@ impl Thread {
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(
tool_use.input.clone(),
&request.messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
self.insert_tool_output(
if tool.needs_confirmation()
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
{
self.tool_use.confirm_tool_use(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
messages.clone(),
ToolType::NonScriptingTool(tool),
);
} else {
self.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
ToolType::NonScriptingTool(tool),
cx,
);
}
} else if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
self.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone().into(),
task,
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
ToolType::NonScriptingTool(tool),
cx,
);
}
@ -1214,36 +1270,13 @@ impl Thread {
.collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses.iter() {
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
Err(err) => Task::ready(Err(err.into())),
Ok(input) => {
let (script_id, script_task) =
self.scripting_session.update(cx, move |session, cx| {
session.run_script(input.lua_script, cx)
});
let session = self.scripting_session.clone();
cx.spawn(async move |_, cx| {
script_task.await;
let message = session.read_with(cx, |session, _cx| {
// Using a id to get the script output seems impractical.
// Why not just include it in the Task result?
// This is because we'll later report the script state as it runs,
session
.get(script_id)
.output_message_for_llm()
.expect("Script shouldn't still be running")
})?;
Ok(message)
})
}
};
let ui_text: SharedString = scripting_tool_use.name.clone().into();
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
self.scripting_tool_use.confirm_tool_use(
scripting_tool_use.id.clone(),
scripting_tool_use.ui_text.clone(),
scripting_tool_use.input.clone(),
messages.clone(),
ToolType::ScriptingTool,
);
}
pending_tool_uses
@ -1251,17 +1284,49 @@ impl Thread {
.chain(pending_scripting_tool_uses)
}
pub fn insert_tool_output(
pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
ui_text: impl Into<SharedString>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
tool_type: ToolType,
cx: &mut Context<'_, Thread>,
) {
let insert_output_task = cx.spawn({
let tool_use_id = tool_use_id.clone();
async move |thread, cx| {
let output = output.await;
match tool_type {
ToolType::ScriptingTool => {
let task = self.spawn_scripting_tool_use(tool_use_id.clone(), input, cx);
self.scripting_tool_use
.run_pending_tool(tool_use_id, ui_text.into(), task);
}
ToolType::NonScriptingTool(tool) => {
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
self.tool_use
.run_pending_tool(tool_use_id, ui_text.into(), task);
}
}
}
fn spawn_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
messages: &[LanguageModelRequestMessage],
input: serde_json::Value,
tool: Arc<dyn Tool>,
cx: &mut Context<Thread>,
) -> Task<()> {
let run_tool = tool.run(
input,
messages,
self.project.clone(),
self.action_log.clone(),
cx,
);
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = run_tool.await;
thread
.update(cx, |thread, cx| {
let pending_tool_use = thread
@ -1276,23 +1341,46 @@ impl Thread {
})
.ok();
}
});
self.tool_use
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
})
}
pub fn insert_scripting_tool_output(
fn spawn_scripting_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
let insert_output_task = cx.spawn({
input: serde_json::Value,
cx: &mut Context<Thread>,
) -> Task<()> {
let task = match ScriptingTool::deserialize_input(input) {
Err(err) => Task::ready(Err(err.into())),
Ok(input) => {
let (script_id, script_task) =
self.scripting_session.update(cx, move |session, cx| {
session.run_script(input.lua_script, cx)
});
let session = self.scripting_session.clone();
cx.spawn(async move |_, cx| {
script_task.await;
let message = session.read_with(cx, |session, _cx| {
// Using an id to get the script output seems impractical.
// Why not just include it in the Task result?
// This is because we'll later report the script state as it runs.
session
.get(script_id)
.output_message_for_llm()
.expect("Script shouldn't still be running")
})?;
Ok(message)
})
}
};
cx.spawn({
let tool_use_id = tool_use_id.clone();
async move |thread, cx| {
let output = output.await;
let output = task.await;
thread
.update(cx, |thread, cx| {
let pending_tool_use = thread
@ -1307,10 +1395,7 @@ impl Thread {
})
.ok();
}
});
self.scripting_tool_use
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
})
}
pub fn attach_tool_results(
@ -1568,6 +1653,30 @@ impl Thread {
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
/// Records that the user refused to let the given tool use run.
///
/// A "permission denied" error is stored as the tool's output in the state
/// matching `tool_type` (scripting or regular), and a
/// `ThreadEvent::ToolFinished` event is emitted with `canceled: true` and no
/// pending tool use attached.
pub fn deny_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    tool_type: ToolType,
    cx: &mut Context<Self>,
) {
    let denied = Err(anyhow::anyhow!(
        "Permission to run tool action denied by user"
    ));

    match tool_type {
        ToolType::ScriptingTool => {
            self.scripting_tool_use
                .insert_tool_output(tool_use_id.clone(), denied);
        }
        ToolType::NonScriptingTool(_) => {
            self.tool_use.insert_tool_output(tool_use_id.clone(), denied);
        }
    }

    cx.emit(ThreadEvent::ToolFinished {
        tool_use_id,
        pending_tool_use: None,
        canceled: true,
    });
}
}
#[derive(Debug, Clone)]

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::future::Shared;
use futures::FutureExt as _;
@ -10,6 +10,7 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, Role,
};
use scripting_tool::ScriptingTool;
use crate::thread::MessageId;
use crate::thread_store::SerializedMessage;
@ -25,6 +26,7 @@ pub struct ToolUse {
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
@ -163,16 +165,19 @@ impl ToolUseState {
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
return match pending_tool_use.status {
match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
PendingToolUseStatus::NeedsConfirmation { .. } => {
ToolUseStatus::NeedsConfirmation
}
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
};
}
} else {
ToolUseStatus::Pending
}
ToolUseStatus::Pending
})();
tool_uses.push(ToolUse {
@ -195,6 +200,8 @@ impl ToolUseState {
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
tool.ui_text(input).into()
} else if tool_name == ScriptingTool::NAME {
"Run Lua Script".into()
} else {
"Unknown tool".into()
}
@ -272,6 +279,28 @@ impl ToolUseState {
}
}
/// Puts a pending tool use on hold until the user confirms it.
///
/// The tool use's `ui_text` is replaced with the given text and its status
/// becomes `PendingToolUseStatus::NeedsConfirmation`, carrying everything
/// passed in here (id, input, messages, tool type and ui text). Does nothing
/// when no pending tool use has the given id.
pub fn confirm_tool_use(
    &mut self,
    tool_use_id: LanguageModelToolUseId,
    ui_text: impl Into<Arc<str>>,
    input: serde_json::Value,
    messages: Arc<Vec<LanguageModelRequestMessage>>,
    tool_type: ToolType,
) {
    let Some(pending) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) else {
        return;
    };

    let label: Arc<str> = ui_text.into();
    pending.ui_text = Arc::clone(&label);
    pending.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(Confirmation {
        tool_use_id,
        input,
        messages,
        tool_type,
        ui_text: label,
    }));
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
@ -369,9 +398,25 @@ pub struct PendingToolUse {
pub status: PendingToolUseStatus,
}
/// Identifies which kind of tool a tool use belongs to.
#[derive(Debug, Clone)]
pub enum ToolType {
/// The scripting tool; carries no payload.
ScriptingTool,
/// Any other tool, carrying the shared `Tool` trait object itself.
NonScriptingTool(Arc<dyn Tool>),
}
/// The data captured for a tool use that is waiting for user confirmation.
///
/// It is stored inside `PendingToolUseStatus::NeedsConfirmation`.
#[derive(Debug, Clone)]
pub struct Confirmation {
/// Id of the tool use awaiting confirmation.
pub tool_use_id: LanguageModelToolUseId,
/// The JSON input for the tool.
pub input: serde_json::Value,
/// Text describing the tool use.
pub ui_text: Arc<str>,
/// The request messages captured alongside the tool use.
pub messages: Arc<Vec<LanguageModelRequestMessage>>,
/// Whether this is a scripting or a non-scripting tool.
pub tool_type: ToolType,
}
/// The state of a tool use that has not produced its final output yet.
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
/// No payload.
/// NOTE(review): presumably the initial state before the tool runs — confirm.
Idle,
/// Carries the shared `Confirmation` for a tool use that needs confirmation.
NeedsConfirmation(Arc<Confirmation>),
/// Holds the shared task associated with the tool use.
/// NOTE(review): presumably the task is stored to keep it alive — confirm.
Running { _task: Shared<Task<()>> },
/// Carries an error message; the payload is marked as unused.
Error(#[allow(unused)] Arc<str>),
}
@ -384,4 +429,8 @@ impl PendingToolUseStatus {
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
pub fn needs_confirmation(&self) -> bool {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}