Cleanup tool registry API surface (#11637)
Fast followups to #11629 Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
79b5556267
commit
9cef0ac869
3 changed files with 89 additions and 106 deletions
|
@ -8,6 +8,6 @@ pub use attachment_registry::{
|
|||
};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
|
||||
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
|
||||
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput,
|
||||
ToolRegistry,
|
||||
};
|
||||
|
|
|
@ -9,10 +9,8 @@ use std::{
|
|||
any::TypeId,
|
||||
collections::HashMap,
|
||||
fmt::Display,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
};
|
||||
use ui::ViewContext;
|
||||
|
||||
|
@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
|
|||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub enum ToolFunctionCallState {
|
||||
enum ToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
|
@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
|
|||
ExecutedTool(Box<dyn ToolView>),
|
||||
}
|
||||
|
||||
pub trait ToolView {
|
||||
trait ToolView {
|
||||
fn view(&self) -> AnyView;
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
fn set_input(&self, input: &str, cx: &mut WindowContext);
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
||||
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
||||
|
@ -48,14 +46,14 @@ pub trait ToolView {
|
|||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct SavedToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub state: SavedToolFunctionCallState,
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
state: SavedToolFunctionCallState,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub enum SavedToolFunctionCallState {
|
||||
enum SavedToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
|
@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
|
|||
ExecutedTool(Box<RawValue>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
|
@ -100,18 +98,6 @@ pub trait LanguageModelTool {
|
|||
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
pub fn tool_running_placeholder() -> AnyElement {
|
||||
ui::Label::new("Researching...").into_any_element()
|
||||
}
|
||||
|
||||
pub fn unknown_tool_placeholder() -> AnyElement {
|
||||
ui::Label::new("Unknown tool").into_any_element()
|
||||
}
|
||||
|
||||
pub fn no_such_tool_placeholder() -> AnyElement {
|
||||
ui::Label::new("No such tool").into_any_element()
|
||||
}
|
||||
|
||||
pub trait ToolOutput: Render {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
|
@ -172,11 +158,6 @@ impl ToolRegistry {
|
|||
.collect()
|
||||
}
|
||||
|
||||
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
|
||||
let tool = self.registered_tools.get(name)?;
|
||||
Some((tool.build_view)(cx))
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&self,
|
||||
call: &mut ToolFunctionCall,
|
||||
|
@ -189,7 +170,8 @@ impl ToolRegistry {
|
|||
}
|
||||
if let Some(arguments) = arguments {
|
||||
if call.arguments.is_empty() {
|
||||
if let Some(view) = self.view_for_tool(&call.name, cx) {
|
||||
if let Some(tool) = self.registered_tools.get(&call.name) {
|
||||
let view = (tool.build_view)(cx);
|
||||
call.state = ToolFunctionCallState::KnownTool(view);
|
||||
} else {
|
||||
call.state = ToolFunctionCallState::NoSuchTool;
|
||||
|
@ -199,7 +181,7 @@ impl ToolRegistry {
|
|||
|
||||
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
||||
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
||||
view.set_input(&repaired_arguments, cx)
|
||||
view.try_set_input(&repaired_arguments, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -207,11 +189,13 @@ impl ToolRegistry {
|
|||
|
||||
pub fn execute_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
tool_call: &mut ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
|
||||
Some(view.execute(cx))
|
||||
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
|
||||
let task = view.execute(cx);
|
||||
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
|
||||
Some(task)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -221,12 +205,14 @@ impl ToolRegistry {
|
|||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
_cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
) -> Option<AnyElement> {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
|
||||
ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
Some(ui::Label::new("No such tool").into_any_element())
|
||||
}
|
||||
ToolFunctionCallState::Initializing => None,
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
view.view().into_any_element()
|
||||
Some(view.view().into_any_element())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -287,12 +273,12 @@ impl ToolRegistry {
|
|||
SavedToolFunctionCallState::KnownTool => {
|
||||
log::error!("Deserialized tool that had not executed");
|
||||
let view = (tool.build_view)(cx);
|
||||
view.set_input(&call.arguments, cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
ToolFunctionCallState::KnownTool(view)
|
||||
}
|
||||
SavedToolFunctionCallState::ExecutedTool(output) => {
|
||||
let view = (tool.build_view)(cx);
|
||||
view.set_input(&call.arguments, cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
view.deserialize_output(output, cx)?;
|
||||
ToolFunctionCallState::ExecutedTool(view)
|
||||
}
|
||||
|
@ -300,13 +286,8 @@ impl ToolRegistry {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Result<()> {
|
||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||
let name = tool.name();
|
||||
let tool = Arc::new(tool);
|
||||
let registered_tool = RegisteredTool {
|
||||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
|
@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
|
|||
self.update(cx, |view, cx| view.generate(project, cx))
|
||||
}
|
||||
|
||||
fn set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
||||
self.update(cx, |view, cx| {
|
||||
view.set_input(input, cx);
|
||||
|
@ -372,7 +353,6 @@ mod test {
|
|||
use super::*;
|
||||
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||
use gpui::{EmptyView, View};
|
||||
use schemars::schema_for;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
@ -483,57 +463,64 @@ mod test {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
||||
cx.background_executor.run_until_parked();
|
||||
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
|
||||
|
||||
let tool = WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools = vec![tool.definition()];
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
let expected = ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: schema_for!(WeatherQuery),
|
||||
};
|
||||
|
||||
assert_eq!(tools[0].name, expected.name);
|
||||
assert_eq!(tools[0].description, expected.description);
|
||||
|
||||
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
expected_schema,
|
||||
json!({
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry
|
||||
.register(WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let definitions = registry.definitions();
|
||||
assert_eq!(
|
||||
definitions,
|
||||
[ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: serde_json::from_value(json!({
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
}))
|
||||
.unwrap(),
|
||||
}]
|
||||
);
|
||||
|
||||
let view = cx.update(|cx| tool.view(cx));
|
||||
let mut call = ToolFunctionCall {
|
||||
id: "the-id".to_string(),
|
||||
name: "get_cur".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
cx.update(|cx| {
|
||||
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
|
||||
let task = cx.update(|cx| {
|
||||
registry.update_tool_call(
|
||||
&mut call,
|
||||
Some("rent_weather"),
|
||||
Some(r#"{"location": "San Francisco","#),
|
||||
cx,
|
||||
);
|
||||
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
|
||||
registry.execute_tool_call(&mut call, cx).unwrap()
|
||||
});
|
||||
task.await.unwrap();
|
||||
|
||||
let finished = cx.update(|cx| view.execute(cx)).await;
|
||||
|
||||
assert!(finished.is_ok());
|
||||
match &call.state {
|
||||
ToolFunctionCallState::ExecutedTool(_view) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue