Cleanup tool registry API surface (#11637)

Fast followups to #11629 

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-05-09 16:43:27 -07:00 committed by GitHub
parent 79b5556267
commit 9cef0ac869
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 89 additions and 106 deletions

View file

@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext}; use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use assistant_tooling::{ use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
UserAttachment,
}; };
use attachments::ActiveEditorAttachmentTool; use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore}; use client::{proto, Client, UserStore};
@ -130,16 +129,13 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new(); let mut tool_registry = ToolRegistry::new();
tool_registry tool_registry
.register(ProjectIndexTool::new(project_index.clone()), cx) .register(ProjectIndexTool::new(project_index.clone()))
.unwrap(); .unwrap();
tool_registry tool_registry
.register( .register(CreateBufferTool::new(workspace.clone(), project.clone()))
CreateBufferTool::new(workspace.clone(), project.clone()),
cx,
)
.unwrap(); .unwrap();
tool_registry tool_registry
.register(AnnotationTool::new(workspace.clone(), project.clone()), cx) .register(AnnotationTool::new(workspace.clone(), project.clone()))
.unwrap(); .unwrap();
let mut attachment_registry = AttachmentRegistry::new(); let mut attachment_registry = AttachmentRegistry::new();
@ -588,9 +584,9 @@ impl AssistantChat {
cx.notify(); cx.notify();
} else { } else {
if let Some(current_message) = messages.last_mut() { if let Some(current_message) = messages.last_mut() {
for tool_call in current_message.tool_calls.iter() { for tool_call in current_message.tool_calls.iter_mut() {
tool_tasks tool_tasks
.extend(this.tool_registry.execute_tool_call(&tool_call, cx)); .extend(this.tool_registry.execute_tool_call(tool_call, cx));
} }
} }
} }
@ -847,7 +843,7 @@ impl AssistantChat {
let tools = message let tools = message
.tool_calls .tool_calls
.iter() .iter()
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) .filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.collect::<Vec<AnyElement>>(); .collect::<Vec<AnyElement>>();
if !tools.is_empty() { if !tools.is_empty() {
@ -856,7 +852,7 @@ impl AssistantChat {
} }
if message_elements.is_empty() { if message_elements.is_empty() {
message_elements.push(tool_running_placeholder()); message_elements.push(::ui::Label::new("Researching...").into_any_element())
} }
div() div()

View file

@ -8,6 +8,6 @@ pub use attachment_registry::{
}; };
pub use project_context::ProjectContext; pub use project_context::ProjectContext;
pub use tool_registry::{ pub use tool_registry::{
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState, LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput,
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry, ToolRegistry,
}; };

View file

@ -9,10 +9,8 @@ use std::{
any::TypeId, any::TypeId,
collections::HashMap, collections::HashMap,
fmt::Display, fmt::Display,
sync::{ mem,
atomic::{AtomicBool, Ordering::SeqCst}, sync::atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
}; };
use ui::ViewContext; use ui::ViewContext;
@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
} }
#[derive(Default)] #[derive(Default)]
pub enum ToolFunctionCallState { enum ToolFunctionCallState {
#[default] #[default]
Initializing, Initializing,
NoSuchTool, NoSuchTool,
@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
ExecutedTool(Box<dyn ToolView>), ExecutedTool(Box<dyn ToolView>),
} }
pub trait ToolView { trait ToolView {
fn view(&self) -> AnyView; fn view(&self) -> AnyView;
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
fn set_input(&self, input: &str, cx: &mut WindowContext); fn try_set_input(&self, input: &str, cx: &mut WindowContext);
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>; fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>; fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
@ -48,14 +46,14 @@ pub trait ToolView {
#[derive(Default, Serialize, Deserialize)] #[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall { pub struct SavedToolFunctionCall {
pub id: String, id: String,
pub name: String, name: String,
pub arguments: String, arguments: String,
pub state: SavedToolFunctionCallState, state: SavedToolFunctionCallState,
} }
#[derive(Default, Serialize, Deserialize)] #[derive(Default, Serialize, Deserialize)]
pub enum SavedToolFunctionCallState { enum SavedToolFunctionCallState {
#[default] #[default]
Initializing, Initializing,
NoSuchTool, NoSuchTool,
@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
ExecutedTool(Box<RawValue>), ExecutedTool(Box<RawValue>),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition { pub struct ToolFunctionDefinition {
pub name: String, pub name: String,
pub description: String, pub description: String,
@ -100,18 +98,6 @@ pub trait LanguageModelTool {
fn view(&self, cx: &mut WindowContext) -> View<Self::View>; fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
} }
pub fn tool_running_placeholder() -> AnyElement {
ui::Label::new("Researching...").into_any_element()
}
pub fn unknown_tool_placeholder() -> AnyElement {
ui::Label::new("Unknown tool").into_any_element()
}
pub fn no_such_tool_placeholder() -> AnyElement {
ui::Label::new("No such tool").into_any_element()
}
pub trait ToolOutput: Render { pub trait ToolOutput: Render {
/// The input type that will be passed in to `execute` when the tool is called /// The input type that will be passed in to `execute` when the tool is called
/// by the language model. /// by the language model.
@ -172,11 +158,6 @@ impl ToolRegistry {
.collect() .collect()
} }
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
let tool = self.registered_tools.get(name)?;
Some((tool.build_view)(cx))
}
pub fn update_tool_call( pub fn update_tool_call(
&self, &self,
call: &mut ToolFunctionCall, call: &mut ToolFunctionCall,
@ -189,7 +170,8 @@ impl ToolRegistry {
} }
if let Some(arguments) = arguments { if let Some(arguments) = arguments {
if call.arguments.is_empty() { if call.arguments.is_empty() {
if let Some(view) = self.view_for_tool(&call.name, cx) { if let Some(tool) = self.registered_tools.get(&call.name) {
let view = (tool.build_view)(cx);
call.state = ToolFunctionCallState::KnownTool(view); call.state = ToolFunctionCallState::KnownTool(view);
} else { } else {
call.state = ToolFunctionCallState::NoSuchTool; call.state = ToolFunctionCallState::NoSuchTool;
@ -199,7 +181,7 @@ impl ToolRegistry {
if let ToolFunctionCallState::KnownTool(view) = &call.state { if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) { if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
view.set_input(&repaired_arguments, cx) view.try_set_input(&repaired_arguments, cx)
} }
} }
} }
@ -207,11 +189,13 @@ impl ToolRegistry {
pub fn execute_tool_call( pub fn execute_tool_call(
&self, &self,
tool_call: &ToolFunctionCall, tool_call: &mut ToolFunctionCall,
cx: &mut WindowContext, cx: &mut WindowContext,
) -> Option<Task<Result<()>>> { ) -> Option<Task<Result<()>>> {
if let ToolFunctionCallState::KnownTool(view) = &tool_call.state { if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
Some(view.execute(cx)) let task = view.execute(cx);
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
Some(task)
} else { } else {
None None
} }
@ -221,12 +205,14 @@ impl ToolRegistry {
&self, &self,
tool_call: &ToolFunctionCall, tool_call: &ToolFunctionCall,
_cx: &mut WindowContext, _cx: &mut WindowContext,
) -> AnyElement { ) -> Option<AnyElement> {
match &tool_call.state { match &tool_call.state {
ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(), ToolFunctionCallState::NoSuchTool => {
ToolFunctionCallState::Initializing => unknown_tool_placeholder(), Some(ui::Label::new("No such tool").into_any_element())
}
ToolFunctionCallState::Initializing => None,
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
view.view().into_any_element() Some(view.view().into_any_element())
} }
} }
} }
@ -287,12 +273,12 @@ impl ToolRegistry {
SavedToolFunctionCallState::KnownTool => { SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed"); log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx); let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx); view.try_set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view) ToolFunctionCallState::KnownTool(view)
} }
SavedToolFunctionCallState::ExecutedTool(output) => { SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx); let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx); view.try_set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?; view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view) ToolFunctionCallState::ExecutedTool(view)
} }
@ -300,13 +286,8 @@ impl ToolRegistry {
}) })
} }
pub fn register<T: 'static + LanguageModelTool>( pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
let name = tool.name(); let name = tool.name();
let tool = Arc::new(tool);
let registered_tool = RegisteredTool { let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(), type_id: TypeId::of::<T>(),
definition: tool.definition(), definition: tool.definition(),
@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
self.update(cx, |view, cx| view.generate(project, cx)) self.update(cx, |view, cx| view.generate(project, cx))
} }
fn set_input(&self, input: &str, cx: &mut WindowContext) { fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
if let Ok(input) = serde_json::from_str::<T::Input>(input) { if let Ok(input) = serde_json::from_str::<T::Input>(input) {
self.update(cx, |view, cx| { self.update(cx, |view, cx| {
view.set_input(input, cx); view.set_input(input, cx);
@ -372,7 +353,6 @@ mod test {
use super::*; use super::*;
use gpui::{div, prelude::*, Render, TestAppContext}; use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View}; use gpui::{EmptyView, View};
use schemars::schema_for;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@ -483,57 +463,64 @@ mod test {
#[gpui::test] #[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) { async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let (_, cx) = cx.add_window_view(|_cx| EmptyView); let (_, cx) = cx.add_window_view(|_cx| EmptyView);
let tool = WeatherTool { let mut registry = ToolRegistry::new();
current_weather: WeatherResult { registry
location: "San Francisco".to_string(), .register(WeatherTool {
temperature: 21.0, current_weather: WeatherResult {
unit: "Celsius".to_string(), location: "San Francisco".to_string(),
}, temperature: 21.0,
}; unit: "Celsius".to_string(),
let tools = vec![tool.definition()];
assert_eq!(tools.len(), 1);
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: schema_for!(WeatherQuery),
};
assert_eq!(tools[0].name, expected.name);
assert_eq!(tools[0].description, expected.description);
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
assert_eq!(
expected_schema,
json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
}, },
"required": ["location", "unit"]
}) })
.unwrap();
let definitions = registry.definitions();
assert_eq!(
definitions,
[ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: serde_json::from_value(json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
}))
.unwrap(),
}]
); );
let view = cx.update(|cx| tool.view(cx)); let mut call = ToolFunctionCall {
id: "the-id".to_string(),
name: "get_cur".to_string(),
..Default::default()
};
cx.update(|cx| { let task = cx.update(|cx| {
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx); registry.update_tool_call(
&mut call,
Some("rent_weather"),
Some(r#"{"location": "San Francisco","#),
cx,
);
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
registry.execute_tool_call(&mut call, cx).unwrap()
}); });
task.await.unwrap();
let finished = cx.update(|cx| view.execute(cx)).await; match &call.state {
ToolFunctionCallState::ExecutedTool(_view) => {}
assert!(finished.is_ok()); _ => panic!(),
}
} }
} }