diff --git a/Cargo.lock b/Cargo.lock index 56850e38c1..da7a2b126a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -490,6 +490,7 @@ dependencies = [ "proto", "rand 0.8.5", "rope", + "scripting_tool", "serde", "serde_json", "settings", @@ -11915,7 +11916,6 @@ name = "scripting_tool" version = "0.1.0" dependencies = [ "anyhow", - "assistant_tool", "collections", "futures 0.3.31", "gpui", @@ -16986,7 +16986,6 @@ dependencies = [ "repl", "reqwest_client", "rope", - "scripting_tool", "search", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index ddcd16eaca..d85c7ef826 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ members = [ "crates/assistant", "crates/assistant2", "crates/assistant_context_editor", - "crates/scripting_tool", "crates/assistant_settings", "crates/assistant_slash_command", "crates/assistant_slash_commands", @@ -119,6 +118,7 @@ members = [ "crates/rope", "crates/rpc", "crates/schema_generator", + "crates/scripting_tool", "crates/search", "crates/semantic_index", "crates/semantic_version", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 66aca29947..1b4210b74f 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -59,6 +59,7 @@ prompt_library.workspace = true prompt_store.workspace = true proto.workspace = true rope.workspace = true +scripting_tool.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index e63ae4cd05..0e8471a4a2 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -457,9 +457,13 @@ impl ActiveThread { let context = thread.context_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id); + let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); // Don't render user messages that are just there for returning tool results. - if message.role == Role::User && thread.message_has_tool_results(message_id) { + if message.role == Role::User + && (thread.message_has_tool_results(message_id) + || thread.message_has_scripting_tool_results(message_id)) + { return Empty.into_any(); } @@ -609,16 +613,22 @@ impl ActiveThread { .id(("message-container", ix)) .child(message_content) .map(|parent| { - if tool_uses.is_empty() { + if tool_uses.is_empty() && scripting_tool_uses.is_empty() { return parent; } parent.child( - v_flex().children( - tool_uses - .into_iter() - .map(|tool_use| self.render_tool_use(tool_use, cx)), - ), + v_flex() + .children( + tool_uses + .into_iter() + .map(|tool_use| self.render_tool_use(tool_use, cx)), + ) + .children( + scripting_tool_uses + .into_iter() + .map(|tool_use| self.render_scripting_tool_use(tool_use, cx)), + ), ) }), Role::System => div().id(("message-container", ix)).py_1().px_2().child( @@ -727,6 +737,15 @@ impl ActiveThread { }), ) } + + fn render_scripting_tool_use( + &self, + tool_use: ToolUse, + cx: &mut Context, + ) -> impl IntoElement { + // TODO: Add custom rendering for scripting tool uses. + self.render_tool_use(tool_use, cx) + } } impl Render for ActiveThread { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 9d19a59d0e..d4b01b9878 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -13,13 +13,14 @@ use language_model::{ Role, StopReason, }; use project::Project; +use scripting_tool::ScriptingTool; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::thread_store::SavedThread; -use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; +use crate::tool_use::{ToolUse, ToolUseState}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -75,6 +76,7 @@ pub struct Thread { project: Entity, tools: Arc, tool_use: ToolUseState, + scripting_tool_use: ToolUseState, } impl Thread { @@ -97,6 +99,7 @@ impl Thread { project, tools, tool_use: ToolUseState::new(), + scripting_tool_use: ToolUseState::new(), } } @@ -115,6 +118,7 @@ impl Thread { .unwrap_or(0), ); let tool_use = ToolUseState::from_saved_messages(&saved.messages); + let scripting_tool_use = ToolUseState::new(); Self { id, @@ -138,6 +142,7 @@ impl Thread { project, tools, tool_use, + scripting_tool_use, } } @@ -198,31 +203,46 @@ impl Thread { ) } - pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { - self.tool_use.pending_tool_uses() - } - /// Returns whether all of the tool uses have finished running. pub fn all_tools_finished(&self) -> bool { + let mut all_pending_tool_uses = self + .tool_use + .pending_tool_uses() + .into_iter() + .chain(self.scripting_tool_use.pending_tool_uses()); + // If the only pending tool uses left are the ones with errors, then that means that we've finished running all // of the pending tools. - self.pending_tool_uses() - .into_iter() - .all(|tool_use| tool_use.status.is_error()) + all_pending_tool_uses.all(|tool_use| tool_use.status.is_error()) } pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { self.tool_use.tool_uses_for_message(id) } + pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec { + self.scripting_tool_use.tool_uses_for_message(id) + } + pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { self.tool_use.tool_results_for_message(id) } + pub fn scripting_tool_results_for_message( + &self, + id: MessageId, + ) -> Vec<&LanguageModelToolResult> { + self.scripting_tool_use.tool_results_for_message(id) + } + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { self.tool_use.message_has_tool_results(message_id) } + pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool { + self.scripting_tool_use.message_has_tool_results(message_id) + } + pub fn insert_user_message( &mut self, text: impl Into, @@ -313,16 +333,25 @@ impl Thread { let mut request = self.to_completion_request(request_kind, cx); if use_tools { - request.tools = self - .tools() - .tools(cx) - .into_iter() - .map(|tool| LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool.input_schema(), - }) - .collect(); + let mut tools = Vec::new(); + tools.push(LanguageModelRequestTool { + name: ScriptingTool::NAME.into(), + description: ScriptingTool::DESCRIPTION.into(), + input_schema: ScriptingTool::input_schema(), + }); + + tools.extend( + self.tools() + .tools(cx) + .into_iter() + .map(|tool| LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + }), + ); + + request.tools = tools; } self.stream_completion(request, model, cx); @@ -357,6 +386,8 @@ impl Thread { RequestKind::Chat => { self.tool_use .attach_tool_results(message.id, &mut request_message); + self.scripting_tool_use + .attach_tool_results(message.id, &mut request_message); } RequestKind::Summarize => { // We don't care about tool use during summarization. @@ -373,6 +404,8 @@ impl Thread { RequestKind::Chat => { self.tool_use .attach_tool_uses(message.id, &mut request_message); + self.scripting_tool_use + .attach_tool_uses(message.id, &mut request_message); } RequestKind::Summarize => { // We don't care about tool use during summarization. @@ -450,9 +483,15 @@ impl Thread { .iter() .rfind(|message| message.role == Role::Assistant) { - thread - .tool_use - .request_tool_use(last_assistant_message.id, tool_use); + if tool_use.name.as_ref() == ScriptingTool::NAME { + thread + .scripting_tool_use + .request_tool_use(last_assistant_message.id, tool_use); + } else { + thread + .tool_use + .request_tool_use(last_assistant_message.id, tool_use); + } } } } @@ -572,6 +611,7 @@ impl Thread { pub fn use_pending_tools(&mut self, cx: &mut Context) { let pending_tool_uses = self + .tool_use .pending_tool_uses() .into_iter() .filter(|tool_use| tool_use.status.is_idle()) @@ -585,6 +625,20 @@ impl Thread { self.insert_tool_output(tool_use.id.clone(), task, cx); } } + + let pending_scripting_tool_uses = self + .scripting_tool_use + .pending_tool_uses() + .into_iter() + .filter(|tool_use| tool_use.status.is_idle()) + .cloned() + .collect::>(); + + for scripting_tool_use in pending_scripting_tool_uses { + let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx); + + self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx); + } } pub fn insert_tool_output( @@ -613,6 +667,32 @@ impl Thread { .run_pending_tool(tool_use_id, insert_output_task); } + pub fn insert_scripting_tool_output( + &mut self, + tool_use_id: LanguageModelToolUseId, + output: Task>, + cx: &mut Context, + ) { + let insert_output_task = cx.spawn(|thread, mut cx| { + let tool_use_id = tool_use_id.clone(); + async move { + let output = output.await; + thread + .update(&mut cx, |thread, cx| { + thread + .scripting_tool_use + .insert_tool_output(tool_use_id.clone(), output); + + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + }) + .ok(); + } + }); + + self.scripting_tool_use + .run_pending_tool(tool_use_id, insert_output_task); + } + pub fn send_tool_results_to_model( &mut self, model: Arc, diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 8340febac1..4161797dc2 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -267,6 +267,7 @@ impl ToolUseState { pub struct PendingToolUse { pub id: LanguageModelToolUseId, /// The ID of the Assistant message in which the tool use was requested. + #[allow(unused)] pub assistant_message_id: MessageId, pub name: Arc, pub input: serde_json::Value, diff --git a/crates/scripting_tool/Cargo.toml b/crates/scripting_tool/Cargo.toml index 18ea42eb87..d2ee28b4e6 100644 --- a/crates/scripting_tool/Cargo.toml +++ b/crates/scripting_tool/Cargo.toml @@ -14,7 +14,6 @@ doctest = false [dependencies] anyhow.workspace = true -assistant_tool.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/scripting_tool/src/scripting_tool.rs b/crates/scripting_tool/src/scripting_tool.rs index 53a4cf944c..885240c9a6 100644 --- a/crates/scripting_tool/src/scripting_tool.rs +++ b/crates/scripting_tool/src/scripting_tool.rs @@ -3,40 +3,29 @@ mod session; use project::Project; use session::*; -use assistant_tool::{Tool, ToolRegistry}; use gpui::{App, AppContext as _, Entity, Task}; use schemars::JsonSchema; use serde::Deserialize; -use std::sync::Arc; - -pub fn init(cx: &App) { - let registry = ToolRegistry::global(cx); - registry.register_tool(ScriptingTool); -} #[derive(Debug, Deserialize, JsonSchema)] struct ScriptingToolInput { lua_script: String, } -struct ScriptingTool; +pub struct ScriptingTool; -impl Tool for ScriptingTool { - fn name(&self) -> String { - "lua-interpreter".into() - } +impl ScriptingTool { + pub const NAME: &str = "lua-interpreter"; - fn description(&self) -> String { - include_str!("scripting_tool_description.txt").into() - } + pub const DESCRIPTION: &str = include_str!("scripting_tool_description.txt"); - fn input_schema(&self) -> serde_json::Value { + pub fn input_schema() -> serde_json::Value { let schema = schemars::schema_for!(ScriptingToolInput); serde_json::to_value(&schema).unwrap() } - fn run( - self: Arc, + pub fn run( + &self, input: serde_json::Value, project: Entity, cx: &mut App, diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 4699a42824..4fcce5cd15 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -98,7 +98,6 @@ remote.workspace = true repl.workspace = true reqwest_client.workspace = true rope.workspace = true -scripting_tool.workspace = true search.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 901cecae4b..4133a15867 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -476,7 +476,6 @@ fn main() { cx, ); assistant_tools::init(cx); - scripting_tool::init(cx); repl::init(app_state.fs.clone(), cx); extension_host::init( extension_host_proxy,