assistant: Add tool registry (#17331)
This PR adds a tool registry to hold tools that can be called by the Assistant. Currently we just have a `now` tool for retrieving the current datetime. This is all behind the `assistant-tool-use` feature flag which currently needs to be explicitly opted-in to in order for the LLM to see the tools. Release Notes: - N/A
This commit is contained in:
parent
c2448e1673
commit
e81b484bf2
11 changed files with 243 additions and 2 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -373,6 +373,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assets",
|
"assets",
|
||||||
"assistant_slash_command",
|
"assistant_slash_command",
|
||||||
|
"assistant_tool",
|
||||||
"async-watch",
|
"async-watch",
|
||||||
"cargo_toml",
|
"cargo_toml",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
@ -454,6 +455,20 @@ dependencies = [
|
||||||
"workspace",
|
"workspace",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "assistant_tool"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"collections",
|
||||||
|
"derive_more",
|
||||||
|
"gpui",
|
||||||
|
"parking_lot",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"workspace",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-attributes"
|
name = "async-attributes"
|
||||||
version = "1.1.2"
|
version = "1.1.2"
|
||||||
|
|
|
@ -6,6 +6,7 @@ members = [
|
||||||
"crates/assets",
|
"crates/assets",
|
||||||
"crates/assistant",
|
"crates/assistant",
|
||||||
"crates/assistant_slash_command",
|
"crates/assistant_slash_command",
|
||||||
|
"crates/assistant_tool",
|
||||||
"crates/audio",
|
"crates/audio",
|
||||||
"crates/auto_update",
|
"crates/auto_update",
|
||||||
"crates/breadcrumbs",
|
"crates/breadcrumbs",
|
||||||
|
@ -181,6 +182,7 @@ anthropic = { path = "crates/anthropic" }
|
||||||
assets = { path = "crates/assets" }
|
assets = { path = "crates/assets" }
|
||||||
assistant = { path = "crates/assistant" }
|
assistant = { path = "crates/assistant" }
|
||||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||||
|
assistant_tool = { path = "crates/assistant_tool" }
|
||||||
audio = { path = "crates/audio" }
|
audio = { path = "crates/audio" }
|
||||||
auto_update = { path = "crates/auto_update" }
|
auto_update = { path = "crates/auto_update" }
|
||||||
breadcrumbs = { path = "crates/breadcrumbs" }
|
breadcrumbs = { path = "crates/breadcrumbs" }
|
||||||
|
|
|
@ -25,6 +25,7 @@ anthropic = { workspace = true, features = ["schemars"] }
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
assets.workspace = true
|
assets.workspace = true
|
||||||
assistant_slash_command.workspace = true
|
assistant_slash_command.workspace = true
|
||||||
|
assistant_tool.workspace = true
|
||||||
async-watch.workspace = true
|
async-watch.workspace = true
|
||||||
cargo_toml.workspace = true
|
cargo_toml.workspace = true
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
|
|
|
@ -13,11 +13,13 @@ pub(crate) mod slash_command_picker;
|
||||||
pub mod slash_command_settings;
|
pub mod slash_command_settings;
|
||||||
mod streaming_diff;
|
mod streaming_diff;
|
||||||
mod terminal_inline_assistant;
|
mod terminal_inline_assistant;
|
||||||
|
mod tools;
|
||||||
mod workflow;
|
mod workflow;
|
||||||
|
|
||||||
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
||||||
use assistant_settings::AssistantSettings;
|
use assistant_settings::AssistantSettings;
|
||||||
use assistant_slash_command::SlashCommandRegistry;
|
use assistant_slash_command::SlashCommandRegistry;
|
||||||
|
use assistant_tool::ToolRegistry;
|
||||||
use client::{proto, Client};
|
use client::{proto, Client};
|
||||||
use command_palette_hooks::CommandPaletteFilter;
|
use command_palette_hooks::CommandPaletteFilter;
|
||||||
pub use context::*;
|
pub use context::*;
|
||||||
|
@ -214,6 +216,7 @@ pub fn init(
|
||||||
prompt_library::init(cx);
|
prompt_library::init(cx);
|
||||||
init_language_model_settings(cx);
|
init_language_model_settings(cx);
|
||||||
assistant_slash_command::init(cx);
|
assistant_slash_command::init(cx);
|
||||||
|
assistant_tool::init(cx);
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
context_servers::init(cx);
|
context_servers::init(cx);
|
||||||
|
|
||||||
|
@ -228,6 +231,7 @@ pub fn init(
|
||||||
.map(Arc::new)
|
.map(Arc::new)
|
||||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||||
register_slash_commands(Some(prompt_builder.clone()), cx);
|
register_slash_commands(Some(prompt_builder.clone()), cx);
|
||||||
|
register_tools(cx);
|
||||||
inline_assistant::init(
|
inline_assistant::init(
|
||||||
fs.clone(),
|
fs.clone(),
|
||||||
prompt_builder.clone(),
|
prompt_builder.clone(),
|
||||||
|
@ -401,6 +405,11 @@ fn update_slash_commands_from_settings(cx: &mut AppContext) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn register_tools(cx: &mut AppContext) {
|
||||||
|
let tool_registry = ToolRegistry::global(cx);
|
||||||
|
tool_registry.register_tool(tools::now_tool::NowTool);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn humanize_token_count(count: usize) -> String {
|
pub fn humanize_token_count(count: usize) -> String {
|
||||||
match count {
|
match count {
|
||||||
0..=999 => count.to_string(),
|
0..=999 => count.to_string(),
|
||||||
|
|
|
@ -9,9 +9,11 @@ use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_slash_command::{
|
use assistant_slash_command::{
|
||||||
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
|
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
|
||||||
};
|
};
|
||||||
|
use assistant_tool::ToolRegistry;
|
||||||
use client::{self, proto, telemetry::Telemetry};
|
use client::{self, proto, telemetry::Telemetry};
|
||||||
use clock::ReplicaId;
|
use clock::ReplicaId;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
|
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||||
use fs::{Fs, RemoveOptions};
|
use fs::{Fs, RemoveOptions};
|
||||||
use futures::{
|
use futures::{
|
||||||
future::{self, Shared},
|
future::{self, Shared},
|
||||||
|
@ -27,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
MessageContent, Role,
|
LanguageModelRequestTool, MessageContent, Role,
|
||||||
};
|
};
|
||||||
use open_ai::Model as OpenAiModel;
|
use open_ai::Model as OpenAiModel;
|
||||||
use paths::{context_images_dir, contexts_dir};
|
use paths::{context_images_dir, contexts_dir};
|
||||||
|
@ -1942,7 +1944,21 @@ impl Context {
|
||||||
// Compute which messages to cache, including the last one.
|
// Compute which messages to cache, including the last one.
|
||||||
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
|
||||||
|
|
||||||
let request = self.to_completion_request(cx);
|
let mut request = self.to_completion_request(cx);
|
||||||
|
|
||||||
|
if cx.has_flag::<ToolUseFeatureFlag>() {
|
||||||
|
let tool_registry = ToolRegistry::global(cx);
|
||||||
|
request.tools = tool_registry
|
||||||
|
.tools()
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool| LanguageModelRequestTool {
|
||||||
|
name: tool.name(),
|
||||||
|
description: tool.description(),
|
||||||
|
input_schema: tool.input_schema(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -2788,6 +2804,16 @@ pub enum PendingSlashCommandStatus {
|
||||||
Error(String),
|
Error(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct ToolUseFeatureFlag;
|
||||||
|
|
||||||
|
impl FeatureFlag for ToolUseFeatureFlag {
|
||||||
|
const NAME: &'static str = "assistant-tool-use";
|
||||||
|
|
||||||
|
fn enabled_for_staff() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PendingToolUse {
|
pub struct PendingToolUse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
|
1
crates/assistant/src/tools.rs
Normal file
1
crates/assistant/src/tools.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
pub mod now_tool;
|
60
crates/assistant/src/tools/now_tool.rs
Normal file
60
crates/assistant/src/tools/now_tool.rs
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use assistant_tool::Tool;
|
||||||
|
use chrono::{Local, Utc};
|
||||||
|
use gpui::{Task, WeakView, WindowContext};
|
||||||
|
use schemars::JsonSchema;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Timezone {
|
||||||
|
/// Use UTC for the datetime.
|
||||||
|
Utc,
|
||||||
|
/// Use local time for the datetime.
|
||||||
|
Local,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||||
|
pub struct FileToolInput {
|
||||||
|
/// The timezone to use for the datetime.
|
||||||
|
timezone: Timezone,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct NowTool;
|
||||||
|
|
||||||
|
impl Tool for NowTool {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"now".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> String {
|
||||||
|
"Returns the current datetime in RFC 3339 format.".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
let schema = schemars::schema_for!(FileToolInput);
|
||||||
|
serde_json::to_value(&schema).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
_workspace: WeakView<workspace::Workspace>,
|
||||||
|
_cx: &mut WindowContext,
|
||||||
|
) -> Task<Result<String>> {
|
||||||
|
let input: FileToolInput = match serde_json::from_value(input) {
|
||||||
|
Ok(input) => input,
|
||||||
|
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = match input.timezone {
|
||||||
|
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||||
|
Timezone::Local => Local::now().to_rfc3339(),
|
||||||
|
};
|
||||||
|
let text = format!("The current datetime is {now}.");
|
||||||
|
|
||||||
|
Task::ready(Ok(text))
|
||||||
|
}
|
||||||
|
}
|
22
crates/assistant_tool/Cargo.toml
Normal file
22
crates/assistant_tool/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
[package]
|
||||||
|
name = "assistant_tool"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
license = "GPL-3.0-or-later"
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/assistant_tool.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow.workspace = true
|
||||||
|
collections.workspace = true
|
||||||
|
derive_more.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
|
parking_lot.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
workspace.workspace = true
|
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE-GPL
|
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
mod tool_registry;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use gpui::{AppContext, Task, WeakView, WindowContext};
|
||||||
|
use workspace::Workspace;
|
||||||
|
|
||||||
|
pub use tool_registry::*;
|
||||||
|
|
||||||
|
pub fn init(cx: &mut AppContext) {
|
||||||
|
ToolRegistry::default_global(cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tool that can be used by a language model.
|
||||||
|
pub trait Tool: 'static + Send + Sync {
|
||||||
|
/// Returns the name of the tool.
|
||||||
|
fn name(&self) -> String;
|
||||||
|
|
||||||
|
/// Returns the description of the tool.
|
||||||
|
fn description(&self) -> String;
|
||||||
|
|
||||||
|
/// Returns the JSON schema that describes the tool's input.
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::Value::Object(serde_json::Map::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs the tool with the provided input.
|
||||||
|
fn run(
|
||||||
|
self: Arc<Self>,
|
||||||
|
input: serde_json::Value,
|
||||||
|
workspace: WeakView<Workspace>,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> Task<Result<String>>;
|
||||||
|
}
|
69
crates/assistant_tool/src/tool_registry.rs
Normal file
69
crates/assistant_tool/src/tool_registry.rs
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use collections::HashMap;
|
||||||
|
use derive_more::{Deref, DerefMut};
|
||||||
|
use gpui::Global;
|
||||||
|
use gpui::{AppContext, ReadGlobal};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
|
use crate::Tool;
|
||||||
|
|
||||||
|
#[derive(Default, Deref, DerefMut)]
|
||||||
|
struct GlobalToolRegistry(Arc<ToolRegistry>);
|
||||||
|
|
||||||
|
impl Global for GlobalToolRegistry {}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct ToolRegistryState {
|
||||||
|
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
state: RwLock<ToolRegistryState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
/// Returns the global [`ToolRegistry`].
|
||||||
|
pub fn global(cx: &AppContext) -> Arc<Self> {
|
||||||
|
GlobalToolRegistry::global(cx).0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the global [`ToolRegistry`].
|
||||||
|
///
|
||||||
|
/// Inserts a default [`ToolRegistry`] if one does not yet exist.
|
||||||
|
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
|
||||||
|
cx.default_global::<GlobalToolRegistry>().0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new() -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
state: RwLock::new(ToolRegistryState {
|
||||||
|
tools: HashMap::default(),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the provided [`Tool`].
|
||||||
|
pub fn register_tool(&self, tool: impl Tool) {
|
||||||
|
let mut state = self.state.write();
|
||||||
|
let tool_name: Arc<str> = tool.name().into();
|
||||||
|
state.tools.insert(tool_name, Arc::new(tool));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unregisters the provided [`Tool`].
|
||||||
|
pub fn unregister_tool(&self, tool: impl Tool) {
|
||||||
|
self.unregister_tool_by_name(tool.name().as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unregisters the tool with the given name.
|
||||||
|
pub fn unregister_tool_by_name(&self, tool_name: &str) {
|
||||||
|
let mut state = self.state.write();
|
||||||
|
state.tools.remove(tool_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the list of tools in the registry.
|
||||||
|
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
||||||
|
self.state.read().tools.values().cloned().collect()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue