diff --git a/Cargo.lock b/Cargo.lock index 48be65f7b8..737c914667 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -373,6 +373,7 @@ dependencies = [ "anyhow", "assets", "assistant_slash_command", + "assistant_tool", "async-watch", "cargo_toml", "chrono", @@ -454,6 +455,20 @@ dependencies = [ "workspace", ] +[[package]] +name = "assistant_tool" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "derive_more", + "gpui", + "parking_lot", + "serde", + "serde_json", + "workspace", +] + [[package]] name = "async-attributes" version = "1.1.2" diff --git a/Cargo.toml b/Cargo.toml index 0a29d72894..68a7167e70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "crates/assets", "crates/assistant", "crates/assistant_slash_command", + "crates/assistant_tool", "crates/audio", "crates/auto_update", "crates/breadcrumbs", @@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant_slash_command = { path = "crates/assistant_slash_command" } +assistant_tool = { path = "crates/assistant_tool" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } breadcrumbs = { path = "crates/breadcrumbs" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 08de8ad694..d2b5aed9bd 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true assets.workspace = true assistant_slash_command.workspace = true +assistant_tool.workspace = true async-watch.workspace = true cargo_toml.workspace = true chrono.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index a962ace527..39df01d06c 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker; pub mod slash_command_settings; mod streaming_diff; mod terminal_inline_assistant; +mod tools; mod workflow; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; +use assistant_tool::ToolRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub use context::*; @@ -214,6 +216,7 @@ pub fn init( prompt_library::init(cx); init_language_model_settings(cx); assistant_slash_command::init(cx); + assistant_tool::init(cx); assistant_panel::init(cx); context_servers::init(cx); @@ -228,6 +231,7 @@ pub fn init( .map(Arc::new) .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); register_slash_commands(Some(prompt_builder.clone()), cx); + register_tools(cx); inline_assistant::init( fs.clone(), prompt_builder.clone(), @@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) { } } +fn register_tools(cx: &mut AppContext) { + let tool_registry = ToolRegistry::global(cx); + tool_registry.register_tool(tools::now_tool::NowTool); +} + pub fn humanize_token_count(count: usize) -> String { match count { 0..=999 => count.to_string(), diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index cb0d45c3b3..51e7d626d7 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, }; +use assistant_tool::ToolRegistry; use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt}; use fs::{Fs, RemoveOptions}; use futures::{ future::{self, Shared}, @@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P use language_model::{ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - MessageContent, Role, + LanguageModelRequestTool, MessageContent, Role, }; use open_ai::Model as OpenAiModel; use paths::{context_images_dir, contexts_dir}; @@ -1942,7 +1944,21 @@ impl Context { // Compute which messages to cache, including the last one. self.mark_cache_anchors(&model.cache_configuration(), false, cx); - let request = self.to_completion_request(cx); + let mut request = self.to_completion_request(cx); + + if cx.has_flag::() { + let tool_registry = ToolRegistry::global(cx); + request.tools = tool_registry + .tools() + .into_iter() + .map(|tool| LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + }) + .collect(); + } + let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus { Error(String), } +pub(crate) struct ToolUseFeatureFlag; + +impl FeatureFlag for ToolUseFeatureFlag { + const NAME: &'static str = "assistant-tool-use"; + + fn enabled_for_staff() -> bool { + false + } +} + #[derive(Debug, Clone)] pub struct PendingToolUse { pub id: String, diff --git a/crates/assistant/src/tools.rs b/crates/assistant/src/tools.rs new file mode 100644 index 0000000000..abde04e760 --- /dev/null +++ b/crates/assistant/src/tools.rs @@ -0,0 +1 @@ +pub mod now_tool; diff --git a/crates/assistant/src/tools/now_tool.rs b/crates/assistant/src/tools/now_tool.rs new file mode 100644 index 0000000000..99034321b1 --- /dev/null +++ b/crates/assistant/src/tools/now_tool.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use anyhow::{anyhow, Result}; +use assistant_tool::Tool; +use chrono::{Local, Utc}; +use gpui::{Task, WeakView, WindowContext}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum Timezone { + /// Use UTC for the datetime. + Utc, + /// Use local time for the datetime. + Local, +} + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FileToolInput { + /// The timezone to use for the datetime. + timezone: Timezone, +} + +pub struct NowTool; + +impl Tool for NowTool { + fn name(&self) -> String { + "now".into() + } + + fn description(&self) -> String { + "Returns the current datetime in RFC 3339 format.".into() + } + + fn input_schema(&self) -> serde_json::Value { + let schema = schemars::schema_for!(FileToolInput); + serde_json::to_value(&schema).unwrap() + } + + fn run( + self: Arc, + input: serde_json::Value, + _workspace: WeakView, + _cx: &mut WindowContext, + ) -> Task> { + let input: FileToolInput = match serde_json::from_value(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))), + }; + + let now = match input.timezone { + Timezone::Utc => Utc::now().to_rfc3339(), + Timezone::Local => Local::now().to_rfc3339(), + }; + let text = format!("The current datetime is {now}."); + + Task::ready(Ok(text)) + } +} diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml new file mode 100644 index 0000000000..90d0ab9142 --- /dev/null +++ b/crates/assistant_tool/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "assistant_tool" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/assistant_tool.rs" + +[dependencies] +anyhow.workspace = true +collections.workspace = true +derive_more.workspace = true +gpui.workspace = true +parking_lot.workspace = true +serde.workspace = true +serde_json.workspace = true +workspace.workspace = true diff --git a/crates/assistant_tool/LICENSE-GPL b/crates/assistant_tool/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/assistant_tool/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs new file mode 100644 index 0000000000..179bfe8dd1 --- /dev/null +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -0,0 +1,35 @@ +mod tool_registry; + +use std::sync::Arc; + +use anyhow::Result; +use gpui::{AppContext, Task, WeakView, WindowContext}; +use workspace::Workspace; + +pub use tool_registry::*; + +pub fn init(cx: &mut AppContext) { + ToolRegistry::default_global(cx); +} + +/// A tool that can be used by a language model. +pub trait Tool: 'static + Send + Sync { + /// Returns the name of the tool. + fn name(&self) -> String; + + /// Returns the description of the tool. + fn description(&self) -> String; + + /// Returns the JSON schema that describes the tool's input. + fn input_schema(&self) -> serde_json::Value { + serde_json::Value::Object(serde_json::Map::default()) + } + + /// Runs the tool with the provided input. + fn run( + self: Arc, + input: serde_json::Value, + workspace: WeakView, + cx: &mut WindowContext, + ) -> Task>; +} diff --git a/crates/assistant_tool/src/tool_registry.rs b/crates/assistant_tool/src/tool_registry.rs new file mode 100644 index 0000000000..d7e1fde4c3 --- /dev/null +++ b/crates/assistant_tool/src/tool_registry.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use collections::HashMap; +use derive_more::{Deref, DerefMut}; +use gpui::Global; +use gpui::{AppContext, ReadGlobal}; +use parking_lot::RwLock; + +use crate::Tool; + +#[derive(Default, Deref, DerefMut)] +struct GlobalToolRegistry(Arc); + +impl Global for GlobalToolRegistry {} + +#[derive(Default)] +struct ToolRegistryState { + tools: HashMap, Arc>, +} + +#[derive(Default)] +pub struct ToolRegistry { + state: RwLock, +} + +impl ToolRegistry { + /// Returns the global [`ToolRegistry`]. + pub fn global(cx: &AppContext) -> Arc { + GlobalToolRegistry::global(cx).0.clone() + } + + /// Returns the global [`ToolRegistry`]. + /// + /// Inserts a default [`ToolRegistry`] if one does not yet exist. + pub fn default_global(cx: &mut AppContext) -> Arc { + cx.default_global::().0.clone() + } + + pub fn new() -> Arc { + Arc::new(Self { + state: RwLock::new(ToolRegistryState { + tools: HashMap::default(), + }), + }) + } + + /// Registers the provided [`Tool`]. + pub fn register_tool(&self, tool: impl Tool) { + let mut state = self.state.write(); + let tool_name: Arc = tool.name().into(); + state.tools.insert(tool_name, Arc::new(tool)); + } + + /// Unregisters the provided [`Tool`]. + pub fn unregister_tool(&self, tool: impl Tool) { + self.unregister_tool_by_name(tool.name().as_str()) + } + + /// Unregisters the tool with the given name. + pub fn unregister_tool_by_name(&self, tool_name: &str) { + let mut state = self.state.write(); + state.tools.remove(tool_name); + } + + /// Returns the list of tools in the registry. + pub fn tools(&self) -> Vec> { + self.state.read().tools.values().cloned().collect() + } +}