assistant: Add tool registry (#17331)
This PR adds a tool registry to hold tools that can be called by the Assistant. Currently we just have a `now` tool for retrieving the current datetime. This is all behind the `assistant-tool-use` feature flag which currently needs to be explicitly opted-in to in order for the LLM to see the tools. Release Notes: - N/A
This commit is contained in:
parent
c2448e1673
commit
e81b484bf2
11 changed files with 243 additions and 2 deletions
22
crates/assistant_tool/Cargo.toml
Normal file
22
crates/assistant_tool/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "assistant_tool"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant_tool.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace.workspace = true
|
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
1
crates/assistant_tool/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
35
crates/assistant_tool/src/assistant_tool.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
mod tool_registry;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::{AppContext, Task, WeakView, WindowContext};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use tool_registry::*;
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
ToolRegistry::default_global(cx);
|
||||
}
|
||||
|
||||
/// A tool that can be used by a language model.
|
||||
pub trait Tool: 'static + Send + Sync {
|
||||
/// Returns the name of the tool.
|
||||
fn name(&self) -> String;
|
||||
|
||||
/// Returns the description of the tool.
|
||||
fn description(&self) -> String;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
serde_json::Value::Object(serde_json::Map::default())
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
workspace: WeakView<Workspace>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<String>>;
|
||||
}
|
69
crates/assistant_tool/src/tool_registry.rs
Normal file
69
crates/assistant_tool/src/tool_registry.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use gpui::Global;
|
||||
use gpui::{AppContext, ReadGlobal};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::Tool;
|
||||
|
||||
#[derive(Default, Deref, DerefMut)]
|
||||
struct GlobalToolRegistry(Arc<ToolRegistry>);
|
||||
|
||||
impl Global for GlobalToolRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ToolRegistryState {
|
||||
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ToolRegistry {
|
||||
state: RwLock<ToolRegistryState>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Returns the global [`ToolRegistry`].
|
||||
pub fn global(cx: &AppContext) -> Arc<Self> {
|
||||
GlobalToolRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ToolRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ToolRegistry`] if one does not yet exist.
|
||||
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
|
||||
cx.default_global::<GlobalToolRegistry>().0.clone()
|
||||
}
|
||||
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
state: RwLock::new(ToolRegistryState {
|
||||
tools: HashMap::default(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
/// Registers the provided [`Tool`].
|
||||
pub fn register_tool(&self, tool: impl Tool) {
|
||||
let mut state = self.state.write();
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
state.tools.insert(tool_name, Arc::new(tool));
|
||||
}
|
||||
|
||||
/// Unregisters the provided [`Tool`].
|
||||
pub fn unregister_tool(&self, tool: impl Tool) {
|
||||
self.unregister_tool_by_name(tool.name().as_str())
|
||||
}
|
||||
|
||||
/// Unregisters the tool with the given name.
|
||||
pub fn unregister_tool_by_name(&self, tool_name: &str) {
|
||||
let mut state = self.state.write();
|
||||
state.tools.remove(tool_name);
|
||||
}
|
||||
|
||||
/// Returns the list of tools in the registry.
|
||||
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.read().tools.values().cloned().collect()
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue