Cleanup tool registry API surface (#11637)
Fast followups to #11629 Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
79b5556267
commit
9cef0ac869
3 changed files with 89 additions and 106 deletions
|
@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant;
|
||||||
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
|
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use assistant_tooling::{
|
use assistant_tooling::{
|
||||||
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
|
AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
|
||||||
UserAttachment,
|
|
||||||
};
|
};
|
||||||
use attachments::ActiveEditorAttachmentTool;
|
use attachments::ActiveEditorAttachmentTool;
|
||||||
use client::{proto, Client, UserStore};
|
use client::{proto, Client, UserStore};
|
||||||
|
@ -130,16 +129,13 @@ impl AssistantPanel {
|
||||||
|
|
||||||
let mut tool_registry = ToolRegistry::new();
|
let mut tool_registry = ToolRegistry::new();
|
||||||
tool_registry
|
tool_registry
|
||||||
.register(ProjectIndexTool::new(project_index.clone()), cx)
|
.register(ProjectIndexTool::new(project_index.clone()))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tool_registry
|
tool_registry
|
||||||
.register(
|
.register(CreateBufferTool::new(workspace.clone(), project.clone()))
|
||||||
CreateBufferTool::new(workspace.clone(), project.clone()),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tool_registry
|
tool_registry
|
||||||
.register(AnnotationTool::new(workspace.clone(), project.clone()), cx)
|
.register(AnnotationTool::new(workspace.clone(), project.clone()))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let mut attachment_registry = AttachmentRegistry::new();
|
let mut attachment_registry = AttachmentRegistry::new();
|
||||||
|
@ -588,9 +584,9 @@ impl AssistantChat {
|
||||||
cx.notify();
|
cx.notify();
|
||||||
} else {
|
} else {
|
||||||
if let Some(current_message) = messages.last_mut() {
|
if let Some(current_message) = messages.last_mut() {
|
||||||
for tool_call in current_message.tool_calls.iter() {
|
for tool_call in current_message.tool_calls.iter_mut() {
|
||||||
tool_tasks
|
tool_tasks
|
||||||
.extend(this.tool_registry.execute_tool_call(&tool_call, cx));
|
.extend(this.tool_registry.execute_tool_call(tool_call, cx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -847,7 +843,7 @@ impl AssistantChat {
|
||||||
let tools = message
|
let tools = message
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
.filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||||
.collect::<Vec<AnyElement>>();
|
.collect::<Vec<AnyElement>>();
|
||||||
|
|
||||||
if !tools.is_empty() {
|
if !tools.is_empty() {
|
||||||
|
@ -856,7 +852,7 @@ impl AssistantChat {
|
||||||
}
|
}
|
||||||
|
|
||||||
if message_elements.is_empty() {
|
if message_elements.is_empty() {
|
||||||
message_elements.push(tool_running_placeholder());
|
message_elements.push(::ui::Label::new("Researching...").into_any_element())
|
||||||
}
|
}
|
||||||
|
|
||||||
div()
|
div()
|
||||||
|
|
|
@ -8,6 +8,6 @@ pub use attachment_registry::{
|
||||||
};
|
};
|
||||||
pub use project_context::ProjectContext;
|
pub use project_context::ProjectContext;
|
||||||
pub use tool_registry::{
|
pub use tool_registry::{
|
||||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
|
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput,
|
||||||
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
|
ToolRegistry,
|
||||||
};
|
};
|
||||||
|
|
|
@ -9,10 +9,8 @@ use std::{
|
||||||
any::TypeId,
|
any::TypeId,
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
fmt::Display,
|
fmt::Display,
|
||||||
sync::{
|
mem,
|
||||||
atomic::{AtomicBool, Ordering::SeqCst},
|
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use ui::ViewContext;
|
use ui::ViewContext;
|
||||||
|
|
||||||
|
@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub enum ToolFunctionCallState {
|
enum ToolFunctionCallState {
|
||||||
#[default]
|
#[default]
|
||||||
Initializing,
|
Initializing,
|
||||||
NoSuchTool,
|
NoSuchTool,
|
||||||
|
@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
|
||||||
ExecutedTool(Box<dyn ToolView>),
|
ExecutedTool(Box<dyn ToolView>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ToolView {
|
trait ToolView {
|
||||||
fn view(&self) -> AnyView;
|
fn view(&self) -> AnyView;
|
||||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||||
fn set_input(&self, input: &str, cx: &mut WindowContext);
|
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
|
||||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
||||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
||||||
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
||||||
|
@ -48,14 +46,14 @@ pub trait ToolView {
|
||||||
|
|
||||||
#[derive(Default, Serialize, Deserialize)]
|
#[derive(Default, Serialize, Deserialize)]
|
||||||
pub struct SavedToolFunctionCall {
|
pub struct SavedToolFunctionCall {
|
||||||
pub id: String,
|
id: String,
|
||||||
pub name: String,
|
name: String,
|
||||||
pub arguments: String,
|
arguments: String,
|
||||||
pub state: SavedToolFunctionCallState,
|
state: SavedToolFunctionCallState,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Serialize, Deserialize)]
|
#[derive(Default, Serialize, Deserialize)]
|
||||||
pub enum SavedToolFunctionCallState {
|
enum SavedToolFunctionCallState {
|
||||||
#[default]
|
#[default]
|
||||||
Initializing,
|
Initializing,
|
||||||
NoSuchTool,
|
NoSuchTool,
|
||||||
|
@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
|
||||||
ExecutedTool(Box<RawValue>),
|
ExecutedTool(Box<RawValue>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub struct ToolFunctionDefinition {
|
pub struct ToolFunctionDefinition {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub description: String,
|
pub description: String,
|
||||||
|
@ -100,18 +98,6 @@ pub trait LanguageModelTool {
|
||||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_running_placeholder() -> AnyElement {
|
|
||||||
ui::Label::new("Researching...").into_any_element()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unknown_tool_placeholder() -> AnyElement {
|
|
||||||
ui::Label::new("Unknown tool").into_any_element()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn no_such_tool_placeholder() -> AnyElement {
|
|
||||||
ui::Label::new("No such tool").into_any_element()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ToolOutput: Render {
|
pub trait ToolOutput: Render {
|
||||||
/// The input type that will be passed in to `execute` when the tool is called
|
/// The input type that will be passed in to `execute` when the tool is called
|
||||||
/// by the language model.
|
/// by the language model.
|
||||||
|
@ -172,11 +158,6 @@ impl ToolRegistry {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
|
|
||||||
let tool = self.registered_tools.get(name)?;
|
|
||||||
Some((tool.build_view)(cx))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update_tool_call(
|
pub fn update_tool_call(
|
||||||
&self,
|
&self,
|
||||||
call: &mut ToolFunctionCall,
|
call: &mut ToolFunctionCall,
|
||||||
|
@ -189,7 +170,8 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
if let Some(arguments) = arguments {
|
if let Some(arguments) = arguments {
|
||||||
if call.arguments.is_empty() {
|
if call.arguments.is_empty() {
|
||||||
if let Some(view) = self.view_for_tool(&call.name, cx) {
|
if let Some(tool) = self.registered_tools.get(&call.name) {
|
||||||
|
let view = (tool.build_view)(cx);
|
||||||
call.state = ToolFunctionCallState::KnownTool(view);
|
call.state = ToolFunctionCallState::KnownTool(view);
|
||||||
} else {
|
} else {
|
||||||
call.state = ToolFunctionCallState::NoSuchTool;
|
call.state = ToolFunctionCallState::NoSuchTool;
|
||||||
|
@ -199,7 +181,7 @@ impl ToolRegistry {
|
||||||
|
|
||||||
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
||||||
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
||||||
view.set_input(&repaired_arguments, cx)
|
view.try_set_input(&repaired_arguments, cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -207,11 +189,13 @@ impl ToolRegistry {
|
||||||
|
|
||||||
pub fn execute_tool_call(
|
pub fn execute_tool_call(
|
||||||
&self,
|
&self,
|
||||||
tool_call: &ToolFunctionCall,
|
tool_call: &mut ToolFunctionCall,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> Option<Task<Result<()>>> {
|
) -> Option<Task<Result<()>>> {
|
||||||
if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
|
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
|
||||||
Some(view.execute(cx))
|
let task = view.execute(cx);
|
||||||
|
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
|
||||||
|
Some(task)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -221,12 +205,14 @@ impl ToolRegistry {
|
||||||
&self,
|
&self,
|
||||||
tool_call: &ToolFunctionCall,
|
tool_call: &ToolFunctionCall,
|
||||||
_cx: &mut WindowContext,
|
_cx: &mut WindowContext,
|
||||||
) -> AnyElement {
|
) -> Option<AnyElement> {
|
||||||
match &tool_call.state {
|
match &tool_call.state {
|
||||||
ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
|
ToolFunctionCallState::NoSuchTool => {
|
||||||
ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
|
Some(ui::Label::new("No such tool").into_any_element())
|
||||||
|
}
|
||||||
|
ToolFunctionCallState::Initializing => None,
|
||||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||||
view.view().into_any_element()
|
Some(view.view().into_any_element())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,12 +273,12 @@ impl ToolRegistry {
|
||||||
SavedToolFunctionCallState::KnownTool => {
|
SavedToolFunctionCallState::KnownTool => {
|
||||||
log::error!("Deserialized tool that had not executed");
|
log::error!("Deserialized tool that had not executed");
|
||||||
let view = (tool.build_view)(cx);
|
let view = (tool.build_view)(cx);
|
||||||
view.set_input(&call.arguments, cx);
|
view.try_set_input(&call.arguments, cx);
|
||||||
ToolFunctionCallState::KnownTool(view)
|
ToolFunctionCallState::KnownTool(view)
|
||||||
}
|
}
|
||||||
SavedToolFunctionCallState::ExecutedTool(output) => {
|
SavedToolFunctionCallState::ExecutedTool(output) => {
|
||||||
let view = (tool.build_view)(cx);
|
let view = (tool.build_view)(cx);
|
||||||
view.set_input(&call.arguments, cx);
|
view.try_set_input(&call.arguments, cx);
|
||||||
view.deserialize_output(output, cx)?;
|
view.deserialize_output(output, cx)?;
|
||||||
ToolFunctionCallState::ExecutedTool(view)
|
ToolFunctionCallState::ExecutedTool(view)
|
||||||
}
|
}
|
||||||
|
@ -300,13 +286,8 @@ impl ToolRegistry {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T: 'static + LanguageModelTool>(
|
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||||
&mut self,
|
|
||||||
tool: T,
|
|
||||||
_cx: &mut WindowContext,
|
|
||||||
) -> Result<()> {
|
|
||||||
let name = tool.name();
|
let name = tool.name();
|
||||||
let tool = Arc::new(tool);
|
|
||||||
let registered_tool = RegisteredTool {
|
let registered_tool = RegisteredTool {
|
||||||
type_id: TypeId::of::<T>(),
|
type_id: TypeId::of::<T>(),
|
||||||
definition: tool.definition(),
|
definition: tool.definition(),
|
||||||
|
@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
|
||||||
self.update(cx, |view, cx| view.generate(project, cx))
|
self.update(cx, |view, cx| view.generate(project, cx))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_input(&self, input: &str, cx: &mut WindowContext) {
|
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||||
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
||||||
self.update(cx, |view, cx| {
|
self.update(cx, |view, cx| {
|
||||||
view.set_input(input, cx);
|
view.set_input(input, cx);
|
||||||
|
@ -372,7 +353,6 @@ mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
use gpui::{div, prelude::*, Render, TestAppContext};
|
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||||
use gpui::{EmptyView, View};
|
use gpui::{EmptyView, View};
|
||||||
use schemars::schema_for;
|
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
@ -483,57 +463,64 @@ mod test {
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
||||||
cx.background_executor.run_until_parked();
|
|
||||||
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
|
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
|
||||||
|
|
||||||
let tool = WeatherTool {
|
let mut registry = ToolRegistry::new();
|
||||||
current_weather: WeatherResult {
|
registry
|
||||||
location: "San Francisco".to_string(),
|
.register(WeatherTool {
|
||||||
temperature: 21.0,
|
current_weather: WeatherResult {
|
||||||
unit: "Celsius".to_string(),
|
location: "San Francisco".to_string(),
|
||||||
},
|
temperature: 21.0,
|
||||||
};
|
unit: "Celsius".to_string(),
|
||||||
|
|
||||||
let tools = vec![tool.definition()];
|
|
||||||
assert_eq!(tools.len(), 1);
|
|
||||||
|
|
||||||
let expected = ToolFunctionDefinition {
|
|
||||||
name: "get_current_weather".to_string(),
|
|
||||||
description: "Fetches the current weather for a given location.".to_string(),
|
|
||||||
parameters: schema_for!(WeatherQuery),
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(tools[0].name, expected.name);
|
|
||||||
assert_eq!(tools[0].description, expected.description);
|
|
||||||
|
|
||||||
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
expected_schema,
|
|
||||||
json!({
|
|
||||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
|
||||||
"title": "WeatherQuery",
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"location": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"unit": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["location", "unit"]
|
|
||||||
})
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let definitions = registry.definitions();
|
||||||
|
assert_eq!(
|
||||||
|
definitions,
|
||||||
|
[ToolFunctionDefinition {
|
||||||
|
name: "get_current_weather".to_string(),
|
||||||
|
description: "Fetches the current weather for a given location.".to_string(),
|
||||||
|
parameters: serde_json::from_value(json!({
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"title": "WeatherQuery",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location", "unit"]
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
}]
|
||||||
);
|
);
|
||||||
|
|
||||||
let view = cx.update(|cx| tool.view(cx));
|
let mut call = ToolFunctionCall {
|
||||||
|
id: "the-id".to_string(),
|
||||||
|
name: "get_cur".to_string(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
cx.update(|cx| {
|
let task = cx.update(|cx| {
|
||||||
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
|
registry.update_tool_call(
|
||||||
|
&mut call,
|
||||||
|
Some("rent_weather"),
|
||||||
|
Some(r#"{"location": "San Francisco","#),
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
|
||||||
|
registry.execute_tool_call(&mut call, cx).unwrap()
|
||||||
});
|
});
|
||||||
|
task.await.unwrap();
|
||||||
|
|
||||||
let finished = cx.update(|cx| view.execute(cx)).await;
|
match &call.state {
|
||||||
|
ToolFunctionCallState::ExecutedTool(_view) => {}
|
||||||
assert!(finished.is_ok());
|
_ => panic!(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue