Prompt before running some tools (#27284)
Also includes some fixes for how the Lua tool was being generated. <img width="644" alt="Screenshot 2025-03-21 at 6 26 18 PM" src="https://github.com/user-attachments/assets/51bd1685-5b3f-4ed3-b11e-6fa8017847d4" /> Release Notes: - N/A --------- Co-authored-by: Ben <ben@zed.dev> Co-authored-by: Agus Zubiaga <agus@zed.dev> Co-authored-by: Joseph T. Lyons <JosephTLyons@gmail.com> Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>
This commit is contained in:
parent
90649fbc89
commit
4c86cda909
17 changed files with 666 additions and 329 deletions
|
@ -3,14 +3,15 @@ use std::io::Write;
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use assistant_tool::{ActionLog, ToolWorkingSet};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
use git;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
|
@ -24,6 +25,7 @@ use prompt_store::{
|
|||
};
|
||||
use scripting_tool::{ScriptingSession, ScriptingTool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -32,7 +34,7 @@ use crate::thread_store::{
|
|||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
|
||||
use crate::tool_use::{PendingToolUse, PendingToolUseStatus, ToolType, ToolUse, ToolUseState};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum RequestKind {
|
||||
|
@ -350,6 +352,44 @@ impl Thread {
|
|||
&self.tools
|
||||
}
|
||||
|
||||
pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
|
||||
self.tool_use
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.find(|tool_use| &tool_use.id == id)
|
||||
.or_else(|| {
|
||||
self.scripting_tool_use
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.find(|tool_use| &tool_use.id == id)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = (ToolType, &PendingToolUse)> {
|
||||
self.tool_use
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.filter_map(|tool_use| {
|
||||
if let PendingToolUseStatus::NeedsConfirmation(confirmation) = &tool_use.status {
|
||||
Some((confirmation.tool_type.clone(), tool_use))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.chain(
|
||||
self.scripting_tool_use
|
||||
.pending_tool_uses()
|
||||
.into_iter()
|
||||
.filter_map(|tool_use| {
|
||||
if tool_use.status.needs_confirmation() {
|
||||
Some((ToolType::ScriptingTool, tool_use))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
|
||||
self.checkpoints_by_message.get(&id).cloned()
|
||||
}
|
||||
|
@ -1178,6 +1218,7 @@ impl Thread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> impl IntoIterator<Item = PendingToolUse> {
|
||||
let request = self.to_completion_request(RequestKind::Chat, cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
let pending_tool_uses = self
|
||||
.tool_use
|
||||
.pending_tool_uses()
|
||||
|
@ -1188,18 +1229,33 @@ impl Thread {
|
|||
|
||||
for tool_use in pending_tool_uses.iter() {
|
||||
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||
let task = tool.run(
|
||||
tool_use.input.clone(),
|
||||
&request.messages,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
self.insert_tool_output(
|
||||
if tool.needs_confirmation()
|
||||
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
|
||||
{
|
||||
self.tool_use.confirm_tool_use(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
messages.clone(),
|
||||
ToolType::NonScriptingTool(tool),
|
||||
);
|
||||
} else {
|
||||
self.run_tool(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
ToolType::NonScriptingTool(tool),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
} else if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
|
||||
self.run_tool(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone().into(),
|
||||
task,
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
ToolType::NonScriptingTool(tool),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
@ -1214,36 +1270,13 @@ impl Thread {
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
for scripting_tool_use in pending_scripting_tool_uses.iter() {
|
||||
let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
|
||||
Err(err) => Task::ready(Err(err.into())),
|
||||
Ok(input) => {
|
||||
let (script_id, script_task) =
|
||||
self.scripting_session.update(cx, move |session, cx| {
|
||||
session.run_script(input.lua_script, cx)
|
||||
});
|
||||
|
||||
let session = self.scripting_session.clone();
|
||||
cx.spawn(async move |_, cx| {
|
||||
script_task.await;
|
||||
|
||||
let message = session.read_with(cx, |session, _cx| {
|
||||
// Using a id to get the script output seems impractical.
|
||||
// Why not just include it in the Task result?
|
||||
// This is because we'll later report the script state as it runs,
|
||||
session
|
||||
.get(script_id)
|
||||
.output_message_for_llm()
|
||||
.expect("Script shouldn't still be running")
|
||||
})?;
|
||||
|
||||
Ok(message)
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let ui_text: SharedString = scripting_tool_use.name.clone().into();
|
||||
|
||||
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
|
||||
self.scripting_tool_use.confirm_tool_use(
|
||||
scripting_tool_use.id.clone(),
|
||||
scripting_tool_use.ui_text.clone(),
|
||||
scripting_tool_use.input.clone(),
|
||||
messages.clone(),
|
||||
ToolType::ScriptingTool,
|
||||
);
|
||||
}
|
||||
|
||||
pending_tool_uses
|
||||
|
@ -1251,17 +1284,49 @@ impl Thread {
|
|||
.chain(pending_scripting_tool_uses)
|
||||
}
|
||||
|
||||
pub fn insert_tool_output(
|
||||
pub fn run_tool(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: SharedString,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut Context<Self>,
|
||||
ui_text: impl Into<SharedString>,
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
tool_type: ToolType,
|
||||
cx: &mut Context<'_, Thread>,
|
||||
) {
|
||||
let insert_output_task = cx.spawn({
|
||||
let tool_use_id = tool_use_id.clone();
|
||||
async move |thread, cx| {
|
||||
let output = output.await;
|
||||
match tool_type {
|
||||
ToolType::ScriptingTool => {
|
||||
let task = self.spawn_scripting_tool_use(tool_use_id.clone(), input, cx);
|
||||
self.scripting_tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text.into(), task);
|
||||
}
|
||||
ToolType::NonScriptingTool(tool) => {
|
||||
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
|
||||
self.tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text.into(), task);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_tool_use(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
input: serde_json::Value,
|
||||
tool: Arc<dyn Tool>,
|
||||
cx: &mut Context<Thread>,
|
||||
) -> Task<()> {
|
||||
let run_tool = tool.run(
|
||||
input,
|
||||
messages,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn({
|
||||
async move |thread: WeakEntity<Thread>, cx| {
|
||||
let output = run_tool.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let pending_tool_use = thread
|
||||
|
@ -1276,23 +1341,46 @@ impl Thread {
|
|||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
|
||||
self.tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_scripting_tool_output(
|
||||
fn spawn_scripting_tool_use(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: SharedString,
|
||||
output: Task<Result<String>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let insert_output_task = cx.spawn({
|
||||
input: serde_json::Value,
|
||||
cx: &mut Context<Thread>,
|
||||
) -> Task<()> {
|
||||
let task = match ScriptingTool::deserialize_input(input) {
|
||||
Err(err) => Task::ready(Err(err.into())),
|
||||
Ok(input) => {
|
||||
let (script_id, script_task) =
|
||||
self.scripting_session.update(cx, move |session, cx| {
|
||||
session.run_script(input.lua_script, cx)
|
||||
});
|
||||
|
||||
let session = self.scripting_session.clone();
|
||||
cx.spawn(async move |_, cx| {
|
||||
script_task.await;
|
||||
|
||||
let message = session.read_with(cx, |session, _cx| {
|
||||
// Using a id to get the script output seems impractical.
|
||||
// Why not just include it in the Task result?
|
||||
// This is because we'll later report the script state as it runs,
|
||||
session
|
||||
.get(script_id)
|
||||
.output_message_for_llm()
|
||||
.expect("Script shouldn't still be running")
|
||||
})?;
|
||||
|
||||
Ok(message)
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
cx.spawn({
|
||||
let tool_use_id = tool_use_id.clone();
|
||||
async move |thread, cx| {
|
||||
let output = output.await;
|
||||
let output = task.await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let pending_tool_use = thread
|
||||
|
@ -1307,10 +1395,7 @@ impl Thread {
|
|||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
|
||||
self.scripting_tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text, insert_output_task);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
|
@ -1568,6 +1653,30 @@ impl Thread {
|
|||
pub fn cumulative_token_usage(&self) -> TokenUsage {
|
||||
self.cumulative_token_usage.clone()
|
||||
}
|
||||
|
||||
pub fn deny_tool_use(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_type: ToolType,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let err = Err(anyhow::anyhow!(
|
||||
"Permission to run tool action denied by user"
|
||||
));
|
||||
|
||||
if let ToolType::ScriptingTool = tool_type {
|
||||
self.scripting_tool_use
|
||||
.insert_tool_output(tool_use_id.clone(), err);
|
||||
} else {
|
||||
self.tool_use.insert_tool_output(tool_use_id.clone(), err);
|
||||
}
|
||||
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use: None,
|
||||
canceled: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue