assistant2: Decouple scripting tool from the Tool
trait (#26382)
This PR decouples the scripting tool from the `Tool` trait while still allowing it to be used as a tool from the model's perspective. This will allow us to evolve the scripting tool as more of a first-class citizen while still retaining the ability to have the model call it as a regular tool. Release Notes: - N/A
This commit is contained in:
parent
2fc4dec58f
commit
e513e81046
10 changed files with 138 additions and 52 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -490,6 +490,7 @@ dependencies = [
|
||||||
"proto",
|
"proto",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"rope",
|
"rope",
|
||||||
|
"scripting_tool",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
|
@ -11915,7 +11916,6 @@ name = "scripting_tool"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assistant_tool",
|
|
||||||
"collections",
|
"collections",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
@ -16986,7 +16986,6 @@ dependencies = [
|
||||||
"repl",
|
"repl",
|
||||||
"reqwest_client",
|
"reqwest_client",
|
||||||
"rope",
|
"rope",
|
||||||
"scripting_tool",
|
|
||||||
"search",
|
"search",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -8,7 +8,6 @@ members = [
|
||||||
"crates/assistant",
|
"crates/assistant",
|
||||||
"crates/assistant2",
|
"crates/assistant2",
|
||||||
"crates/assistant_context_editor",
|
"crates/assistant_context_editor",
|
||||||
"crates/scripting_tool",
|
|
||||||
"crates/assistant_settings",
|
"crates/assistant_settings",
|
||||||
"crates/assistant_slash_command",
|
"crates/assistant_slash_command",
|
||||||
"crates/assistant_slash_commands",
|
"crates/assistant_slash_commands",
|
||||||
|
@ -119,6 +118,7 @@ members = [
|
||||||
"crates/rope",
|
"crates/rope",
|
||||||
"crates/rpc",
|
"crates/rpc",
|
||||||
"crates/schema_generator",
|
"crates/schema_generator",
|
||||||
|
"crates/scripting_tool",
|
||||||
"crates/search",
|
"crates/search",
|
||||||
"crates/semantic_index",
|
"crates/semantic_index",
|
||||||
"crates/semantic_version",
|
"crates/semantic_version",
|
||||||
|
|
|
@ -59,6 +59,7 @@ prompt_library.workspace = true
|
||||||
prompt_store.workspace = true
|
prompt_store.workspace = true
|
||||||
proto.workspace = true
|
proto.workspace = true
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
|
scripting_tool.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
settings.workspace = true
|
settings.workspace = true
|
||||||
|
|
|
@ -457,9 +457,13 @@ impl ActiveThread {
|
||||||
|
|
||||||
let context = thread.context_for_message(message_id);
|
let context = thread.context_for_message(message_id);
|
||||||
let tool_uses = thread.tool_uses_for_message(message_id);
|
let tool_uses = thread.tool_uses_for_message(message_id);
|
||||||
|
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
|
||||||
|
|
||||||
// Don't render user messages that are just there for returning tool results.
|
// Don't render user messages that are just there for returning tool results.
|
||||||
if message.role == Role::User && thread.message_has_tool_results(message_id) {
|
if message.role == Role::User
|
||||||
|
&& (thread.message_has_tool_results(message_id)
|
||||||
|
|| thread.message_has_scripting_tool_results(message_id))
|
||||||
|
{
|
||||||
return Empty.into_any();
|
return Empty.into_any();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -609,16 +613,22 @@ impl ActiveThread {
|
||||||
.id(("message-container", ix))
|
.id(("message-container", ix))
|
||||||
.child(message_content)
|
.child(message_content)
|
||||||
.map(|parent| {
|
.map(|parent| {
|
||||||
if tool_uses.is_empty() {
|
if tool_uses.is_empty() && scripting_tool_uses.is_empty() {
|
||||||
return parent;
|
return parent;
|
||||||
}
|
}
|
||||||
|
|
||||||
parent.child(
|
parent.child(
|
||||||
v_flex().children(
|
v_flex()
|
||||||
tool_uses
|
.children(
|
||||||
.into_iter()
|
tool_uses
|
||||||
.map(|tool_use| self.render_tool_use(tool_use, cx)),
|
.into_iter()
|
||||||
),
|
.map(|tool_use| self.render_tool_use(tool_use, cx)),
|
||||||
|
)
|
||||||
|
.children(
|
||||||
|
scripting_tool_uses
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
}),
|
}),
|
||||||
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
|
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
|
||||||
|
@ -727,6 +737,15 @@ impl ActiveThread {
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn render_scripting_tool_use(
|
||||||
|
&self,
|
||||||
|
tool_use: ToolUse,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> impl IntoElement {
|
||||||
|
// TODO: Add custom rendering for scripting tool uses.
|
||||||
|
self.render_tool_use(tool_use, cx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Render for ActiveThread {
|
impl Render for ActiveThread {
|
||||||
|
|
|
@ -13,13 +13,14 @@ use language_model::{
|
||||||
Role, StopReason,
|
Role, StopReason,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
|
use scripting_tool::ScriptingTool;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use util::{post_inc, TryFutureExt as _};
|
use util::{post_inc, TryFutureExt as _};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
|
||||||
use crate::thread_store::SavedThread;
|
use crate::thread_store::SavedThread;
|
||||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
|
use crate::tool_use::{ToolUse, ToolUseState};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum RequestKind {
|
pub enum RequestKind {
|
||||||
|
@ -75,6 +76,7 @@ pub struct Thread {
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
tool_use: ToolUseState,
|
tool_use: ToolUseState,
|
||||||
|
scripting_tool_use: ToolUseState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
|
@ -97,6 +99,7 @@ impl Thread {
|
||||||
project,
|
project,
|
||||||
tools,
|
tools,
|
||||||
tool_use: ToolUseState::new(),
|
tool_use: ToolUseState::new(),
|
||||||
|
scripting_tool_use: ToolUseState::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,6 +118,7 @@ impl Thread {
|
||||||
.unwrap_or(0),
|
.unwrap_or(0),
|
||||||
);
|
);
|
||||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||||
|
let scripting_tool_use = ToolUseState::new();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
|
@ -138,6 +142,7 @@ impl Thread {
|
||||||
project,
|
project,
|
||||||
tools,
|
tools,
|
||||||
tool_use,
|
tool_use,
|
||||||
|
scripting_tool_use,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,31 +203,46 @@ impl Thread {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
|
|
||||||
self.tool_use.pending_tool_uses()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns whether all of the tool uses have finished running.
|
/// Returns whether all of the tool uses have finished running.
|
||||||
pub fn all_tools_finished(&self) -> bool {
|
pub fn all_tools_finished(&self) -> bool {
|
||||||
|
let mut all_pending_tool_uses = self
|
||||||
|
.tool_use
|
||||||
|
.pending_tool_uses()
|
||||||
|
.into_iter()
|
||||||
|
.chain(self.scripting_tool_use.pending_tool_uses());
|
||||||
|
|
||||||
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all
|
// If the only pending tool uses left are the ones with errors, then that means that we've finished running all
|
||||||
// of the pending tools.
|
// of the pending tools.
|
||||||
self.pending_tool_uses()
|
all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
|
||||||
.into_iter()
|
|
||||||
.all(|tool_use| tool_use.status.is_error())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
||||||
self.tool_use.tool_uses_for_message(id)
|
self.tool_use.tool_uses_for_message(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
|
||||||
|
self.scripting_tool_use.tool_uses_for_message(id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
|
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||||
self.tool_use.tool_results_for_message(id)
|
self.tool_use.tool_results_for_message(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn scripting_tool_results_for_message(
|
||||||
|
&self,
|
||||||
|
id: MessageId,
|
||||||
|
) -> Vec<&LanguageModelToolResult> {
|
||||||
|
self.scripting_tool_use.tool_results_for_message(id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
self.tool_use.message_has_tool_results(message_id)
|
self.tool_use.message_has_tool_results(message_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
|
||||||
|
self.scripting_tool_use.message_has_tool_results(message_id)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn insert_user_message(
|
pub fn insert_user_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
text: impl Into<String>,
|
text: impl Into<String>,
|
||||||
|
@ -313,16 +333,25 @@ impl Thread {
|
||||||
let mut request = self.to_completion_request(request_kind, cx);
|
let mut request = self.to_completion_request(request_kind, cx);
|
||||||
|
|
||||||
if use_tools {
|
if use_tools {
|
||||||
request.tools = self
|
let mut tools = Vec::new();
|
||||||
.tools()
|
tools.push(LanguageModelRequestTool {
|
||||||
.tools(cx)
|
name: ScriptingTool::NAME.into(),
|
||||||
.into_iter()
|
description: ScriptingTool::DESCRIPTION.into(),
|
||||||
.map(|tool| LanguageModelRequestTool {
|
input_schema: ScriptingTool::input_schema(),
|
||||||
name: tool.name(),
|
});
|
||||||
description: tool.description(),
|
|
||||||
input_schema: tool.input_schema(),
|
tools.extend(
|
||||||
})
|
self.tools()
|
||||||
.collect();
|
.tools(cx)
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| LanguageModelRequestTool {
|
||||||
|
name: tool.name(),
|
||||||
|
description: tool.description(),
|
||||||
|
input_schema: tool.input_schema(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
request.tools = tools;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.stream_completion(request, model, cx);
|
self.stream_completion(request, model, cx);
|
||||||
|
@ -357,6 +386,8 @@ impl Thread {
|
||||||
RequestKind::Chat => {
|
RequestKind::Chat => {
|
||||||
self.tool_use
|
self.tool_use
|
||||||
.attach_tool_results(message.id, &mut request_message);
|
.attach_tool_results(message.id, &mut request_message);
|
||||||
|
self.scripting_tool_use
|
||||||
|
.attach_tool_results(message.id, &mut request_message);
|
||||||
}
|
}
|
||||||
RequestKind::Summarize => {
|
RequestKind::Summarize => {
|
||||||
// We don't care about tool use during summarization.
|
// We don't care about tool use during summarization.
|
||||||
|
@ -373,6 +404,8 @@ impl Thread {
|
||||||
RequestKind::Chat => {
|
RequestKind::Chat => {
|
||||||
self.tool_use
|
self.tool_use
|
||||||
.attach_tool_uses(message.id, &mut request_message);
|
.attach_tool_uses(message.id, &mut request_message);
|
||||||
|
self.scripting_tool_use
|
||||||
|
.attach_tool_uses(message.id, &mut request_message);
|
||||||
}
|
}
|
||||||
RequestKind::Summarize => {
|
RequestKind::Summarize => {
|
||||||
// We don't care about tool use during summarization.
|
// We don't care about tool use during summarization.
|
||||||
|
@ -450,9 +483,15 @@ impl Thread {
|
||||||
.iter()
|
.iter()
|
||||||
.rfind(|message| message.role == Role::Assistant)
|
.rfind(|message| message.role == Role::Assistant)
|
||||||
{
|
{
|
||||||
thread
|
if tool_use.name.as_ref() == ScriptingTool::NAME {
|
||||||
.tool_use
|
thread
|
||||||
.request_tool_use(last_assistant_message.id, tool_use);
|
.scripting_tool_use
|
||||||
|
.request_tool_use(last_assistant_message.id, tool_use);
|
||||||
|
} else {
|
||||||
|
thread
|
||||||
|
.tool_use
|
||||||
|
.request_tool_use(last_assistant_message.id, tool_use);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -572,6 +611,7 @@ impl Thread {
|
||||||
|
|
||||||
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
|
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
|
.tool_use
|
||||||
.pending_tool_uses()
|
.pending_tool_uses()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|tool_use| tool_use.status.is_idle())
|
.filter(|tool_use| tool_use.status.is_idle())
|
||||||
|
@ -585,6 +625,20 @@ impl Thread {
|
||||||
self.insert_tool_output(tool_use.id.clone(), task, cx);
|
self.insert_tool_output(tool_use.id.clone(), task, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let pending_scripting_tool_uses = self
|
||||||
|
.scripting_tool_use
|
||||||
|
.pending_tool_uses()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|tool_use| tool_use.status.is_idle())
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for scripting_tool_use in pending_scripting_tool_uses {
|
||||||
|
let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx);
|
||||||
|
|
||||||
|
self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert_tool_output(
|
pub fn insert_tool_output(
|
||||||
|
@ -613,6 +667,32 @@ impl Thread {
|
||||||
.run_pending_tool(tool_use_id, insert_output_task);
|
.run_pending_tool(tool_use_id, insert_output_task);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn insert_scripting_tool_output(
|
||||||
|
&mut self,
|
||||||
|
tool_use_id: LanguageModelToolUseId,
|
||||||
|
output: Task<Result<String>>,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let insert_output_task = cx.spawn(|thread, mut cx| {
|
||||||
|
let tool_use_id = tool_use_id.clone();
|
||||||
|
async move {
|
||||||
|
let output = output.await;
|
||||||
|
thread
|
||||||
|
.update(&mut cx, |thread, cx| {
|
||||||
|
thread
|
||||||
|
.scripting_tool_use
|
||||||
|
.insert_tool_output(tool_use_id.clone(), output);
|
||||||
|
|
||||||
|
cx.emit(ThreadEvent::ToolFinished { tool_use_id });
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
self.scripting_tool_use
|
||||||
|
.run_pending_tool(tool_use_id, insert_output_task);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn send_tool_results_to_model(
|
pub fn send_tool_results_to_model(
|
||||||
&mut self,
|
&mut self,
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
|
|
|
@ -267,6 +267,7 @@ impl ToolUseState {
|
||||||
pub struct PendingToolUse {
|
pub struct PendingToolUse {
|
||||||
pub id: LanguageModelToolUseId,
|
pub id: LanguageModelToolUseId,
|
||||||
/// The ID of the Assistant message in which the tool use was requested.
|
/// The ID of the Assistant message in which the tool use was requested.
|
||||||
|
#[allow(unused)]
|
||||||
pub assistant_message_id: MessageId,
|
pub assistant_message_id: MessageId,
|
||||||
pub name: Arc<str>,
|
pub name: Arc<str>,
|
||||||
pub input: serde_json::Value,
|
pub input: serde_json::Value,
|
||||||
|
|
|
@ -14,7 +14,6 @@ doctest = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assistant_tool.workspace = true
|
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
gpui.workspace = true
|
gpui.workspace = true
|
||||||
|
|
|
@ -3,40 +3,29 @@ mod session;
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use session::*;
|
use session::*;
|
||||||
|
|
||||||
use assistant_tool::{Tool, ToolRegistry};
|
|
||||||
use gpui::{App, AppContext as _, Entity, Task};
|
use gpui::{App, AppContext as _, Entity, Task};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub fn init(cx: &App) {
|
|
||||||
let registry = ToolRegistry::global(cx);
|
|
||||||
registry.register_tool(ScriptingTool);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, JsonSchema)]
|
#[derive(Debug, Deserialize, JsonSchema)]
|
||||||
struct ScriptingToolInput {
|
struct ScriptingToolInput {
|
||||||
lua_script: String,
|
lua_script: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ScriptingTool;
|
pub struct ScriptingTool;
|
||||||
|
|
||||||
impl Tool for ScriptingTool {
|
impl ScriptingTool {
|
||||||
fn name(&self) -> String {
|
pub const NAME: &str = "lua-interpreter";
|
||||||
"lua-interpreter".into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> String {
|
pub const DESCRIPTION: &str = include_str!("scripting_tool_description.txt");
|
||||||
include_str!("scripting_tool_description.txt").into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> serde_json::Value {
|
pub fn input_schema() -> serde_json::Value {
|
||||||
let schema = schemars::schema_for!(ScriptingToolInput);
|
let schema = schemars::schema_for!(ScriptingToolInput);
|
||||||
serde_json::to_value(&schema).unwrap()
|
serde_json::to_value(&schema).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(
|
pub fn run(
|
||||||
self: Arc<Self>,
|
&self,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
project: Entity<Project>,
|
project: Entity<Project>,
|
||||||
cx: &mut App,
|
cx: &mut App,
|
||||||
|
|
|
@ -98,7 +98,6 @@ remote.workspace = true
|
||||||
repl.workspace = true
|
repl.workspace = true
|
||||||
reqwest_client.workspace = true
|
reqwest_client.workspace = true
|
||||||
rope.workspace = true
|
rope.workspace = true
|
||||||
scripting_tool.workspace = true
|
|
||||||
search.workspace = true
|
search.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
|
|
@ -476,7 +476,6 @@ fn main() {
|
||||||
cx,
|
cx,
|
||||||
);
|
);
|
||||||
assistant_tools::init(cx);
|
assistant_tools::init(cx);
|
||||||
scripting_tool::init(cx);
|
|
||||||
repl::init(app_state.fs.clone(), cx);
|
repl::init(app_state.fs.clone(), cx);
|
||||||
extension_host::init(
|
extension_host::init(
|
||||||
extension_host_proxy,
|
extension_host_proxy,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue