Cleanup tool registry API surface (#11637)

Fast followups to #11629 

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-05-09 16:43:27 -07:00 committed by GitHub
parent 79b5556267
commit 9cef0ac869
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 89 additions and 106 deletions

View file

@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
UserAttachment,
AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
};
use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
@ -130,16 +129,13 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new();
tool_registry
.register(ProjectIndexTool::new(project_index.clone()), cx)
.register(ProjectIndexTool::new(project_index.clone()))
.unwrap();
tool_registry
.register(
CreateBufferTool::new(workspace.clone(), project.clone()),
cx,
)
.register(CreateBufferTool::new(workspace.clone(), project.clone()))
.unwrap();
tool_registry
.register(AnnotationTool::new(workspace.clone(), project.clone()), cx)
.register(AnnotationTool::new(workspace.clone(), project.clone()))
.unwrap();
let mut attachment_registry = AttachmentRegistry::new();
@ -588,9 +584,9 @@ impl AssistantChat {
cx.notify();
} else {
if let Some(current_message) = messages.last_mut() {
for tool_call in current_message.tool_calls.iter() {
for tool_call in current_message.tool_calls.iter_mut() {
tool_tasks
.extend(this.tool_registry.execute_tool_call(&tool_call, cx));
.extend(this.tool_registry.execute_tool_call(tool_call, cx));
}
}
}
@ -847,7 +843,7 @@ impl AssistantChat {
let tools = message
.tool_calls
.iter()
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
.collect::<Vec<AnyElement>>();
if !tools.is_empty() {
@ -856,7 +852,7 @@ impl AssistantChat {
}
if message_elements.is_empty() {
message_elements.push(tool_running_placeholder());
message_elements.push(::ui::Label::new("Researching...").into_any_element())
}
div()

View file

@ -8,6 +8,6 @@ pub use attachment_registry::{
};
pub use project_context::ProjectContext;
pub use tool_registry::{
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput,
ToolRegistry,
};

View file

@ -9,10 +9,8 @@ use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
mem,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use ui::ViewContext;
@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
}
#[derive(Default)]
pub enum ToolFunctionCallState {
enum ToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
ExecutedTool(Box<dyn ToolView>),
}
pub trait ToolView {
trait ToolView {
fn view(&self) -> AnyView;
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
fn set_input(&self, input: &str, cx: &mut WindowContext);
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
@ -48,14 +46,14 @@ pub trait ToolView {
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub state: SavedToolFunctionCallState,
id: String,
name: String,
arguments: String,
state: SavedToolFunctionCallState,
}
#[derive(Default, Serialize, Deserialize)]
pub enum SavedToolFunctionCallState {
enum SavedToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
ExecutedTool(Box<RawValue>),
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
@ -100,18 +98,6 @@ pub trait LanguageModelTool {
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
pub fn tool_running_placeholder() -> AnyElement {
ui::Label::new("Researching...").into_any_element()
}
pub fn unknown_tool_placeholder() -> AnyElement {
ui::Label::new("Unknown tool").into_any_element()
}
pub fn no_such_tool_placeholder() -> AnyElement {
ui::Label::new("No such tool").into_any_element()
}
pub trait ToolOutput: Render {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
@ -172,11 +158,6 @@ impl ToolRegistry {
.collect()
}
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
let tool = self.registered_tools.get(name)?;
Some((tool.build_view)(cx))
}
pub fn update_tool_call(
&self,
call: &mut ToolFunctionCall,
@ -189,7 +170,8 @@ impl ToolRegistry {
}
if let Some(arguments) = arguments {
if call.arguments.is_empty() {
if let Some(view) = self.view_for_tool(&call.name, cx) {
if let Some(tool) = self.registered_tools.get(&call.name) {
let view = (tool.build_view)(cx);
call.state = ToolFunctionCallState::KnownTool(view);
} else {
call.state = ToolFunctionCallState::NoSuchTool;
@ -199,7 +181,7 @@ impl ToolRegistry {
if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
view.set_input(&repaired_arguments, cx)
view.try_set_input(&repaired_arguments, cx)
}
}
}
@ -207,11 +189,13 @@ impl ToolRegistry {
pub fn execute_tool_call(
&self,
tool_call: &ToolFunctionCall,
tool_call: &mut ToolFunctionCall,
cx: &mut WindowContext,
) -> Option<Task<Result<()>>> {
if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
Some(view.execute(cx))
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
let task = view.execute(cx);
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
Some(task)
} else {
None
}
@ -221,12 +205,14 @@ impl ToolRegistry {
&self,
tool_call: &ToolFunctionCall,
_cx: &mut WindowContext,
) -> AnyElement {
) -> Option<AnyElement> {
match &tool_call.state {
ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
ToolFunctionCallState::NoSuchTool => {
Some(ui::Label::new("No such tool").into_any_element())
}
ToolFunctionCallState::Initializing => None,
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
view.view().into_any_element()
Some(view.view().into_any_element())
}
}
}
@ -287,12 +273,12 @@ impl ToolRegistry {
SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx);
view.try_set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view)
}
SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx);
view.try_set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view)
}
@ -300,13 +286,8 @@ impl ToolRegistry {
})
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
let name = tool.name();
let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
self.update(cx, |view, cx| view.generate(project, cx))
}
fn set_input(&self, input: &str, cx: &mut WindowContext) {
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
self.update(cx, |view, cx| {
view.set_input(input, cx);
@ -372,7 +353,6 @@ mod test {
use super::*;
use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View};
use schemars::schema_for;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
@ -483,57 +463,64 @@ mod test {
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
let tools = vec![tool.definition()];
assert_eq!(tools.len(), 1);
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: schema_for!(WeatherQuery),
};
assert_eq!(tools[0].name, expected.name);
assert_eq!(tools[0].description, expected.description);
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
assert_eq!(
expected_schema,
json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
let mut registry = ToolRegistry::new();
registry
.register(WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
"required": ["location", "unit"]
})
.unwrap();
let definitions = registry.definitions();
assert_eq!(
definitions,
[ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: serde_json::from_value(json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
}))
.unwrap(),
}]
);
let view = cx.update(|cx| tool.view(cx));
let mut call = ToolFunctionCall {
id: "the-id".to_string(),
name: "get_cur".to_string(),
..Default::default()
};
cx.update(|cx| {
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
let task = cx.update(|cx| {
registry.update_tool_call(
&mut call,
Some("rent_weather"),
Some(r#"{"location": "San Francisco","#),
cx,
);
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
registry.execute_tool_call(&mut call, cx).unwrap()
});
task.await.unwrap();
let finished = cx.update(|cx| view.execute(cx)).await;
assert!(finished.is_ok());
match &call.state {
ToolFunctionCallState::ExecutedTool(_view) => {}
_ => panic!(),
}
}
}