ZIm/crates/assistant_tooling/src/tool_registry.rs
Max Brunsfeld a7aa2578e1
Implement serialization of assistant conversations, including tool calls and attachments (#11577)
Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
2024-05-08 17:52:15 -04:00

581 lines
19 KiB
Rust

use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<ToolFunctionCallResult>,
}
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<SavedToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
}
#[derive(Serialize, Deserialize)]
pub enum SavedToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
serialized_output: Result<Box<RawValue>, String>,
},
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
/// Returns the name of the tool.
///
/// This name is exposed to the language model to allow the model to pick
/// which tools to use. As this name is used to identify the tool within a
/// tool registry, it should be unique.
fn name(&self) -> String;
/// Returns the description of the tool.
///
/// This can be used to _prompt_ the model as to what the tool does.
fn description(&self) -> String;
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(Self::Input);
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: root_schema,
}
}
/// Executes the tool with the given input.
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
/// A view of the output of running the tool, for displaying to the user.
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
tool_running_placeholder()
}
}
pub fn tool_running_placeholder() -> AnyElement {
ui::Label::new("Researching...").into_any_element()
}
pub trait ToolOutput: Sized {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
definition: ToolFunctionDefinition,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
registered_tools: HashMap::new(),
}
}
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
}
}
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
}
false
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.registered_tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
.collect()
}
pub fn render_tool_call(
&self,
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> AnyElement {
match &tool_call.result {
Some(result) => div()
.p_2()
.child(result.into_any_element(&tool_call.name))
.into_any_element(),
None => {
let tool = self.registered_tools.get(&tool_call.name);
if let Some(tool) = tool {
(tool.render_running)(&tool_call, cx)
} else {
tool_running_placeholder()
}
}
}
}
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: call.result.as_ref().map(|result| match result {
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
ToolFunctionCallResult::Finished {
serialized_output, ..
} => SavedToolFunctionCallResult::Finished {
serialized_output: match serialized_output {
Ok(value) => Ok(value.clone()),
Err(e) => Err(e.to_string()),
},
},
}),
}
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> ToolFunctionCall {
if let Some(tool) = &self.registered_tools.get(&call.name) {
(tool.deserialize)(call, cx)
} else {
ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: Some(ToolFunctionCallResult::NoSuchTool),
}
}
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
let name = tool.name();
let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
deserialize: Box::new({
let tool = tool.clone();
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
};
};
let result = match &tool_call.result {
Some(result) => match result {
SavedToolFunctionCallResult::NoSuchTool => {
Some(ToolFunctionCallResult::NoSuchTool)
}
SavedToolFunctionCallResult::ParsingFailed => {
Some(ToolFunctionCallResult::ParsingFailed)
}
SavedToolFunctionCallResult::Finished { serialized_output } => {
let output = match serialized_output {
Ok(value) => {
match serde_json::from_str::<T::Output>(value.get()) {
Ok(value) => Ok(value),
Err(_) => {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(
ToolFunctionCallResult::ParsingFailed,
),
};
}
}
}
Err(e) => Err(anyhow!("{e}")),
};
let view = tool.view(input, output, cx).into();
Some(ToolFunctionCallResult::Finished {
serialized_output: serialized_output.clone(),
generate_fn: generate::<T>,
view,
})
}
},
None => None,
};
ToolFunctionCall {
id: tool_call.id.clone(),
name: name.clone(),
arguments: tool_call.arguments.clone(),
result,
}
}
}),
execute: Box::new({
let tool = tool.clone();
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
}));
};
let result = tool.execute(&input, cx);
let tool = tool.clone();
cx.spawn(move |mut cx| async move {
let result = result.await;
let serialized_output = result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| tool.view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
serialized_output,
view: view.into(),
generate_fn: generate::<T>,
}),
})
})
}
}),
render_running: render_running::<T>,
};
let previous = self.registered_tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
return Ok(());
fn render_running<T: LanguageModelTool>(
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> AnyElement {
// Attempt to parse the string arguments that are JSON as a JSON value
let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
T::render_running(&maybe_arguments, cx).into_any_element()
}
fn generate<T: LanguageModelTool>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
}
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
pub fn call(
&self,
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> Task<Result<ToolFunctionCall>> {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let tool = match self.registered_tools.get(&name) {
Some(tool) => tool,
None => {
let name = name.clone();
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::NoSuchTool),
}));
}
};
(tool.execute)(tool_call, cx)
}
}
impl ToolFunctionCallResult {
pub fn generate(
&self,
name: &String,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished {
generate_fn, view, ..
} => (generate_fn)(view.clone(), project, cx),
}
}
fn into_any_element(&self, name: &String) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
format!("Language Model attempted to call {name}").into_any_element()
}
ToolFunctionCallResult::ParsingFailed => {
format!("Language Model called {name} with bad arguments").into_any_element()
}
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
}
}
}
impl Display for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
write!(f, "Name: {}:\n", self.name)?;
write!(f, "Description: {}\n", self.description)?;
write!(f, "Parameters: {}", schema)
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View};
use schemars::schema_for;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
location: String,
unit: String,
}
struct WeatherTool {
current_weather: WeatherResult,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
temperature: f64,
unit: String,
}
struct WeatherView {
result: WeatherResult,
}
impl Render for WeatherView {
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
div().child(format!("temperature: {}", self.result.temperature))
}
}
impl ToolOutput for WeatherView {
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
serde_json::to_string(&self.result).unwrap()
}
}
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
type View = WeatherView;
fn name(&self) -> String {
"get_current_weather".to_string()
}
fn description(&self) -> String {
"Fetches the current weather for a given location.".to_string()
}
fn execute(
&self,
input: &Self::Input,
_cx: &mut WindowContext,
) -> Task<Result<Self::Output>> {
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
Task::ready(Ok(weather))
}
fn view(
&self,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View> {
cx.new_view(|_cx| {
let result = result.unwrap();
WeatherView { result }
})
}
}
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
let tools = vec![tool.definition()];
assert_eq!(tools.len(), 1);
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: schema_for!(WeatherQuery),
};
assert_eq!(tools[0].name, expected.name);
assert_eq!(tools[0].description, expected.description);
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
assert_eq!(
expected_schema,
json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
})
);
let args = json!({
"location": "San Francisco",
"unit": "Celsius"
});
let query: WeatherQuery = serde_json::from_value(args).unwrap();
let result = cx.update(|cx| tool.execute(&query, cx)).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result, tool.current_weather);
}
}