assistant: Add tool registry (#17331)

This PR adds a tool registry to hold tools that can be called by the
Assistant.

Currently we just have a `now` tool for retrieving the current datetime.

This is all behind the `assistant-tool-use` feature flag which currently
needs to be explicitly opted-in to in order for the LLM to see the
tools.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-09-03 19:14:36 -04:00 committed by GitHub
parent c2448e1673
commit e81b484bf2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 243 additions and 2 deletions

View file

@ -0,0 +1,35 @@
mod tool_registry;
use std::sync::Arc;
use anyhow::Result;
use gpui::{AppContext, Task, WeakView, WindowContext};
use workspace::Workspace;
pub use tool_registry::*;
pub fn init(cx: &mut AppContext) {
ToolRegistry::default_global(cx);
}
/// A tool that can be used by a language model.
pub trait Tool: 'static + Send + Sync {
/// Returns the name of the tool.
fn name(&self) -> String;
/// Returns the description of the tool.
fn description(&self) -> String;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: serde_json::Value,
workspace: WeakView<Workspace>,
cx: &mut WindowContext,
) -> Task<Result<String>>;
}

View file

@ -0,0 +1,69 @@
use std::sync::Arc;
use collections::HashMap;
use derive_more::{Deref, DerefMut};
use gpui::Global;
use gpui::{AppContext, ReadGlobal};
use parking_lot::RwLock;
use crate::Tool;
#[derive(Default, Deref, DerefMut)]
struct GlobalToolRegistry(Arc<ToolRegistry>);
impl Global for GlobalToolRegistry {}
#[derive(Default)]
struct ToolRegistryState {
tools: HashMap<Arc<str>, Arc<dyn Tool>>,
}
#[derive(Default)]
pub struct ToolRegistry {
state: RwLock<ToolRegistryState>,
}
impl ToolRegistry {
/// Returns the global [`ToolRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> {
GlobalToolRegistry::global(cx).0.clone()
}
/// Returns the global [`ToolRegistry`].
///
/// Inserts a default [`ToolRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
cx.default_global::<GlobalToolRegistry>().0.clone()
}
pub fn new() -> Arc<Self> {
Arc::new(Self {
state: RwLock::new(ToolRegistryState {
tools: HashMap::default(),
}),
})
}
/// Registers the provided [`Tool`].
pub fn register_tool(&self, tool: impl Tool) {
let mut state = self.state.write();
let tool_name: Arc<str> = tool.name().into();
state.tools.insert(tool_name, Arc::new(tool));
}
/// Unregisters the provided [`Tool`].
pub fn unregister_tool(&self, tool: impl Tool) {
self.unregister_tool_by_name(tool.name().as_str())
}
/// Unregisters the tool with the given name.
pub fn unregister_tool_by_name(&self, tool_name: &str) {
let mut state = self.state.write();
state.tools.remove(tool_name);
}
/// Returns the list of tools in the registry.
pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
self.state.read().tools.values().cloned().collect()
}
}