assistant2: Decouple scripting tool from the Tool trait (#26382)

This PR decouples the scripting tool from the `Tool` trait while still
allowing it to be used as a tool from the model's perspective.

This will allow us to evolve the scripting tool as more of a first-class
citizen while still retaining the ability to have the model call it as a
regular tool.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-10 13:57:03 -04:00 committed by GitHub
parent 2fc4dec58f
commit e513e81046
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 138 additions and 52 deletions

3
Cargo.lock generated
View file

@ -490,6 +490,7 @@ dependencies = [
"proto", "proto",
"rand 0.8.5", "rand 0.8.5",
"rope", "rope",
"scripting_tool",
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
@ -11915,7 +11916,6 @@ name = "scripting_tool"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assistant_tool",
"collections", "collections",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
@ -16986,7 +16986,6 @@ dependencies = [
"repl", "repl",
"reqwest_client", "reqwest_client",
"rope", "rope",
"scripting_tool",
"search", "search",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -8,7 +8,6 @@ members = [
"crates/assistant", "crates/assistant",
"crates/assistant2", "crates/assistant2",
"crates/assistant_context_editor", "crates/assistant_context_editor",
"crates/scripting_tool",
"crates/assistant_settings", "crates/assistant_settings",
"crates/assistant_slash_command", "crates/assistant_slash_command",
"crates/assistant_slash_commands", "crates/assistant_slash_commands",
@ -119,6 +118,7 @@ members = [
"crates/rope", "crates/rope",
"crates/rpc", "crates/rpc",
"crates/schema_generator", "crates/schema_generator",
"crates/scripting_tool",
"crates/search", "crates/search",
"crates/semantic_index", "crates/semantic_index",
"crates/semantic_version", "crates/semantic_version",

View file

@ -59,6 +59,7 @@ prompt_library.workspace = true
prompt_store.workspace = true prompt_store.workspace = true
proto.workspace = true proto.workspace = true
rope.workspace = true rope.workspace = true
scripting_tool.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true settings.workspace = true

View file

@ -457,9 +457,13 @@ impl ActiveThread {
let context = thread.context_for_message(message_id); let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
// Don't render user messages that are just there for returning tool results. // Don't render user messages that are just there for returning tool results.
if message.role == Role::User && thread.message_has_tool_results(message_id) { if message.role == Role::User
&& (thread.message_has_tool_results(message_id)
|| thread.message_has_scripting_tool_results(message_id))
{
return Empty.into_any(); return Empty.into_any();
} }
@ -609,16 +613,22 @@ impl ActiveThread {
.id(("message-container", ix)) .id(("message-container", ix))
.child(message_content) .child(message_content)
.map(|parent| { .map(|parent| {
if tool_uses.is_empty() { if tool_uses.is_empty() && scripting_tool_uses.is_empty() {
return parent; return parent;
} }
parent.child( parent.child(
v_flex().children( v_flex()
tool_uses .children(
.into_iter() tool_uses
.map(|tool_use| self.render_tool_use(tool_use, cx)), .into_iter()
), .map(|tool_use| self.render_tool_use(tool_use, cx)),
)
.children(
scripting_tool_uses
.into_iter()
.map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
),
) )
}), }),
Role::System => div().id(("message-container", ix)).py_1().px_2().child( Role::System => div().id(("message-container", ix)).py_1().px_2().child(
@ -727,6 +737,15 @@ impl ActiveThread {
}), }),
) )
} }
fn render_scripting_tool_use(
&self,
tool_use: ToolUse,
cx: &mut Context<Self>,
) -> impl IntoElement {
// TODO: Add custom rendering for scripting tool uses.
self.render_tool_use(tool_use, cx)
}
} }
impl Render for ActiveThread { impl Render for ActiveThread {

View file

@ -13,13 +13,14 @@ use language_model::{
Role, StopReason, Role, StopReason,
}; };
use project::Project; use project::Project;
use scripting_tool::ScriptingTool;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _}; use util::{post_inc, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::SavedThread; use crate::thread_store::SavedThread;
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; use crate::tool_use::{ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum RequestKind { pub enum RequestKind {
@ -75,6 +76,7 @@ pub struct Thread {
project: Entity<Project>, project: Entity<Project>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState, tool_use: ToolUseState,
scripting_tool_use: ToolUseState,
} }
impl Thread { impl Thread {
@ -97,6 +99,7 @@ impl Thread {
project, project,
tools, tools,
tool_use: ToolUseState::new(), tool_use: ToolUseState::new(),
scripting_tool_use: ToolUseState::new(),
} }
} }
@ -115,6 +118,7 @@ impl Thread {
.unwrap_or(0), .unwrap_or(0),
); );
let tool_use = ToolUseState::from_saved_messages(&saved.messages); let tool_use = ToolUseState::from_saved_messages(&saved.messages);
let scripting_tool_use = ToolUseState::new();
Self { Self {
id, id,
@ -138,6 +142,7 @@ impl Thread {
project, project,
tools, tools,
tool_use, tool_use,
scripting_tool_use,
} }
} }
@ -198,31 +203,46 @@ impl Thread {
) )
} }
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.tool_use.pending_tool_uses()
}
/// Returns whether all of the tool uses have finished running. /// Returns whether all of the tool uses have finished running.
pub fn all_tools_finished(&self) -> bool { pub fn all_tools_finished(&self) -> bool {
let mut all_pending_tool_uses = self
.tool_use
.pending_tool_uses()
.into_iter()
.chain(self.scripting_tool_use.pending_tool_uses());
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
// of the pending tools. // of the pending tools.
self.pending_tool_uses() all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
.into_iter()
.all(|tool_use| tool_use.status.is_error())
} }
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> { pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.tool_use.tool_uses_for_message(id) self.tool_use.tool_uses_for_message(id)
} }
pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
self.scripting_tool_use.tool_uses_for_message(id)
}
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(id) self.tool_use.tool_results_for_message(id)
} }
pub fn scripting_tool_results_for_message(
&self,
id: MessageId,
) -> Vec<&LanguageModelToolResult> {
self.scripting_tool_use.tool_results_for_message(id)
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id) self.tool_use.message_has_tool_results(message_id)
} }
pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
self.scripting_tool_use.message_has_tool_results(message_id)
}
pub fn insert_user_message( pub fn insert_user_message(
&mut self, &mut self,
text: impl Into<String>, text: impl Into<String>,
@ -313,16 +333,25 @@ impl Thread {
let mut request = self.to_completion_request(request_kind, cx); let mut request = self.to_completion_request(request_kind, cx);
if use_tools { if use_tools {
request.tools = self let mut tools = Vec::new();
.tools() tools.push(LanguageModelRequestTool {
.tools(cx) name: ScriptingTool::NAME.into(),
.into_iter() description: ScriptingTool::DESCRIPTION.into(),
.map(|tool| LanguageModelRequestTool { input_schema: ScriptingTool::input_schema(),
name: tool.name(), });
description: tool.description(),
input_schema: tool.input_schema(), tools.extend(
}) self.tools()
.collect(); .tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool.input_schema(),
}),
);
request.tools = tools;
} }
self.stream_completion(request, model, cx); self.stream_completion(request, model, cx);
@ -357,6 +386,8 @@ impl Thread {
RequestKind::Chat => { RequestKind::Chat => {
self.tool_use self.tool_use
.attach_tool_results(message.id, &mut request_message); .attach_tool_results(message.id, &mut request_message);
self.scripting_tool_use
.attach_tool_results(message.id, &mut request_message);
} }
RequestKind::Summarize => { RequestKind::Summarize => {
// We don't care about tool use during summarization. // We don't care about tool use during summarization.
@ -373,6 +404,8 @@ impl Thread {
RequestKind::Chat => { RequestKind::Chat => {
self.tool_use self.tool_use
.attach_tool_uses(message.id, &mut request_message); .attach_tool_uses(message.id, &mut request_message);
self.scripting_tool_use
.attach_tool_uses(message.id, &mut request_message);
} }
RequestKind::Summarize => { RequestKind::Summarize => {
// We don't care about tool use during summarization. // We don't care about tool use during summarization.
@ -450,9 +483,15 @@ impl Thread {
.iter() .iter()
.rfind(|message| message.role == Role::Assistant) .rfind(|message| message.role == Role::Assistant)
{ {
thread if tool_use.name.as_ref() == ScriptingTool::NAME {
.tool_use thread
.request_tool_use(last_assistant_message.id, tool_use); .scripting_tool_use
.request_tool_use(last_assistant_message.id, tool_use);
} else {
thread
.tool_use
.request_tool_use(last_assistant_message.id, tool_use);
}
} }
} }
} }
@ -572,6 +611,7 @@ impl Thread {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) { pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
let pending_tool_uses = self let pending_tool_uses = self
.tool_use
.pending_tool_uses() .pending_tool_uses()
.into_iter() .into_iter()
.filter(|tool_use| tool_use.status.is_idle()) .filter(|tool_use| tool_use.status.is_idle())
@ -585,6 +625,20 @@ impl Thread {
self.insert_tool_output(tool_use.id.clone(), task, cx); self.insert_tool_output(tool_use.id.clone(), task, cx);
} }
} }
let pending_scripting_tool_uses = self
.scripting_tool_use
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses {
let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx);
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
}
} }
pub fn insert_tool_output( pub fn insert_tool_output(
@ -613,6 +667,32 @@ impl Thread {
.run_pending_tool(tool_use_id, insert_output_task); .run_pending_tool(tool_use_id, insert_output_task);
} }
pub fn insert_scripting_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
output: Task<Result<String>>,
cx: &mut Context<Self>,
) {
let insert_output_task = cx.spawn(|thread, mut cx| {
let tool_use_id = tool_use_id.clone();
async move {
let output = output.await;
thread
.update(&mut cx, |thread, cx| {
thread
.scripting_tool_use
.insert_tool_output(tool_use_id.clone(), output);
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
})
.ok();
}
});
self.scripting_tool_use
.run_pending_tool(tool_use_id, insert_output_task);
}
pub fn send_tool_results_to_model( pub fn send_tool_results_to_model(
&mut self, &mut self,
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,

View file

@ -267,6 +267,7 @@ impl ToolUseState {
pub struct PendingToolUse { pub struct PendingToolUse {
pub id: LanguageModelToolUseId, pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested. /// The ID of the Assistant message in which the tool use was requested.
#[allow(unused)]
pub assistant_message_id: MessageId, pub assistant_message_id: MessageId,
pub name: Arc<str>, pub name: Arc<str>,
pub input: serde_json::Value, pub input: serde_json::Value,

View file

@ -14,7 +14,6 @@ doctest = false
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true collections.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true

View file

@ -3,40 +3,29 @@ mod session;
use project::Project; use project::Project;
use session::*; use session::*;
use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Entity, Task}; use gpui::{App, AppContext as _, Entity, Task};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::Deserialize; use serde::Deserialize;
use std::sync::Arc;
pub fn init(cx: &App) {
let registry = ToolRegistry::global(cx);
registry.register_tool(ScriptingTool);
}
#[derive(Debug, Deserialize, JsonSchema)] #[derive(Debug, Deserialize, JsonSchema)]
struct ScriptingToolInput { struct ScriptingToolInput {
lua_script: String, lua_script: String,
} }
struct ScriptingTool; pub struct ScriptingTool;
impl Tool for ScriptingTool { impl ScriptingTool {
fn name(&self) -> String { pub const NAME: &str = "lua-interpreter";
"lua-interpreter".into()
}
fn description(&self) -> String { pub const DESCRIPTION: &str = include_str!("scripting_tool_description.txt");
include_str!("scripting_tool_description.txt").into()
}
fn input_schema(&self) -> serde_json::Value { pub fn input_schema() -> serde_json::Value {
let schema = schemars::schema_for!(ScriptingToolInput); let schema = schemars::schema_for!(ScriptingToolInput);
serde_json::to_value(&schema).unwrap() serde_json::to_value(&schema).unwrap()
} }
fn run( pub fn run(
self: Arc<Self>, &self,
input: serde_json::Value, input: serde_json::Value,
project: Entity<Project>, project: Entity<Project>,
cx: &mut App, cx: &mut App,

View file

@ -98,7 +98,6 @@ remote.workspace = true
repl.workspace = true repl.workspace = true
reqwest_client.workspace = true reqwest_client.workspace = true
rope.workspace = true rope.workspace = true
scripting_tool.workspace = true
search.workspace = true search.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true

View file

@ -476,7 +476,6 @@ fn main() {
cx, cx,
); );
assistant_tools::init(cx); assistant_tools::init(cx);
scripting_tool::init(cx);
repl::init(app_state.fs.clone(), cx); repl::init(app_state.fs.clone(), cx);
extension_host::init( extension_host::init(
extension_host_proxy, extension_host_proxy,