diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 35100f9882..51161b7490 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant; use ::ui::{div, prelude::*, Color, Tooltip, ViewContext}; use anyhow::{Context, Result}; use assistant_tooling::{ - tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, - UserAttachment, + AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment, }; use attachments::ActiveEditorAttachmentTool; use client::{proto, Client, UserStore}; @@ -130,16 +129,13 @@ impl AssistantPanel { let mut tool_registry = ToolRegistry::new(); tool_registry - .register(ProjectIndexTool::new(project_index.clone()), cx) + .register(ProjectIndexTool::new(project_index.clone())) .unwrap(); tool_registry - .register( - CreateBufferTool::new(workspace.clone(), project.clone()), - cx, - ) + .register(CreateBufferTool::new(workspace.clone(), project.clone())) .unwrap(); tool_registry - .register(AnnotationTool::new(workspace.clone(), project.clone()), cx) + .register(AnnotationTool::new(workspace.clone(), project.clone())) .unwrap(); let mut attachment_registry = AttachmentRegistry::new(); @@ -588,9 +584,9 @@ impl AssistantChat { cx.notify(); } else { if let Some(current_message) = messages.last_mut() { - for tool_call in current_message.tool_calls.iter() { + for tool_call in current_message.tool_calls.iter_mut() { tool_tasks - .extend(this.tool_registry.execute_tool_call(&tool_call, cx)); + .extend(this.tool_registry.execute_tool_call(tool_call, cx)); } } } @@ -847,7 +843,7 @@ impl AssistantChat { let tools = message .tool_calls .iter() - .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) + .filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx)) .collect::>(); if !tools.is_empty() { @@ -856,7 +852,7 @@ impl AssistantChat { } if message_elements.is_empty() { - message_elements.push(tool_running_placeholder()); + message_elements.push(::ui::Label::new("Researching...").into_any_element()) } div() diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs index dd4dac39e9..9a45ad9b1d 100644 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ b/crates/assistant_tooling/src/assistant_tooling.rs @@ -8,6 +8,6 @@ pub use attachment_registry::{ }; pub use project_context::ProjectContext; pub use tool_registry::{ - tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState, - ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry, + LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, + ToolRegistry, }; diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs index 7f0a8fb296..98793e4b8f 100644 --- a/crates/assistant_tooling/src/tool_registry.rs +++ b/crates/assistant_tooling/src/tool_registry.rs @@ -9,10 +9,8 @@ use std::{ any::TypeId, collections::HashMap, fmt::Display, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, + mem, + sync::atomic::{AtomicBool, Ordering::SeqCst}, }; use ui::ViewContext; @@ -29,7 +27,7 @@ pub struct ToolFunctionCall { } #[derive(Default)] -pub enum ToolFunctionCallState { +enum ToolFunctionCallState { #[default] Initializing, NoSuchTool, @@ -37,10 +35,10 @@ pub enum ToolFunctionCallState { ExecutedTool(Box), } -pub trait ToolView { +trait ToolView { fn view(&self) -> AnyView; fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; - fn set_input(&self, input: &str, cx: &mut WindowContext); + fn try_set_input(&self, input: &str, cx: &mut WindowContext); fn execute(&self, cx: &mut WindowContext) -> Task>; fn serialize_output(&self, cx: &mut WindowContext) -> Result>; fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; @@ -48,14 +46,14 @@ pub trait ToolView { #[derive(Default, Serialize, Deserialize)] pub struct SavedToolFunctionCall { - pub id: String, - pub name: String, - pub arguments: String, - pub state: SavedToolFunctionCallState, + id: String, + name: String, + arguments: String, + state: SavedToolFunctionCallState, } #[derive(Default, Serialize, Deserialize)] -pub enum SavedToolFunctionCallState { +enum SavedToolFunctionCallState { #[default] Initializing, NoSuchTool, @@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState { ExecutedTool(Box), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct ToolFunctionDefinition { pub name: String, pub description: String, @@ -100,18 +98,6 @@ pub trait LanguageModelTool { fn view(&self, cx: &mut WindowContext) -> View; } -pub fn tool_running_placeholder() -> AnyElement { - ui::Label::new("Researching...").into_any_element() -} - -pub fn unknown_tool_placeholder() -> AnyElement { - ui::Label::new("Unknown tool").into_any_element() -} - -pub fn no_such_tool_placeholder() -> AnyElement { - ui::Label::new("No such tool").into_any_element() -} - pub trait ToolOutput: Render { /// The input type that will be passed in to `execute` when the tool is called /// by the language model. @@ -172,11 +158,6 @@ impl ToolRegistry { .collect() } - pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option> { - let tool = self.registered_tools.get(name)?; - Some((tool.build_view)(cx)) - } - pub fn update_tool_call( &self, call: &mut ToolFunctionCall, @@ -189,7 +170,8 @@ impl ToolRegistry { } if let Some(arguments) = arguments { if call.arguments.is_empty() { - if let Some(view) = self.view_for_tool(&call.name, cx) { + if let Some(tool) = self.registered_tools.get(&call.name) { + let view = (tool.build_view)(cx); call.state = ToolFunctionCallState::KnownTool(view); } else { call.state = ToolFunctionCallState::NoSuchTool; @@ -199,7 +181,7 @@ impl ToolRegistry { if let ToolFunctionCallState::KnownTool(view) = &call.state { if let Ok(repaired_arguments) = repair(call.arguments.clone()) { - view.set_input(&repaired_arguments, cx) + view.try_set_input(&repaired_arguments, cx) } } } @@ -207,11 +189,13 @@ impl ToolRegistry { pub fn execute_tool_call( &self, - tool_call: &ToolFunctionCall, + tool_call: &mut ToolFunctionCall, cx: &mut WindowContext, ) -> Option>> { - if let ToolFunctionCallState::KnownTool(view) = &tool_call.state { - Some(view.execute(cx)) + if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) { + let task = view.execute(cx); + tool_call.state = ToolFunctionCallState::ExecutedTool(view); + Some(task) } else { None } @@ -221,12 +205,14 @@ impl ToolRegistry { &self, tool_call: &ToolFunctionCall, _cx: &mut WindowContext, - ) -> AnyElement { + ) -> Option { match &tool_call.state { - ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(), - ToolFunctionCallState::Initializing => unknown_tool_placeholder(), + ToolFunctionCallState::NoSuchTool => { + Some(ui::Label::new("No such tool").into_any_element()) + } + ToolFunctionCallState::Initializing => None, ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - view.view().into_any_element() + Some(view.view().into_any_element()) } } } @@ -287,12 +273,12 @@ impl ToolRegistry { SavedToolFunctionCallState::KnownTool => { log::error!("Deserialized tool that had not executed"); let view = (tool.build_view)(cx); - view.set_input(&call.arguments, cx); + view.try_set_input(&call.arguments, cx); ToolFunctionCallState::KnownTool(view) } SavedToolFunctionCallState::ExecutedTool(output) => { let view = (tool.build_view)(cx); - view.set_input(&call.arguments, cx); + view.try_set_input(&call.arguments, cx); view.deserialize_output(output, cx)?; ToolFunctionCallState::ExecutedTool(view) } @@ -300,13 +286,8 @@ impl ToolRegistry { }) } - pub fn register( - &mut self, - tool: T, - _cx: &mut WindowContext, - ) -> Result<()> { + pub fn register(&mut self, tool: T) -> Result<()> { let name = tool.name(); - let tool = Arc::new(tool); let registered_tool = RegisteredTool { type_id: TypeId::of::(), definition: tool.definition(), @@ -332,7 +313,7 @@ impl ToolView for View { self.update(cx, |view, cx| view.generate(project, cx)) } - fn set_input(&self, input: &str, cx: &mut WindowContext) { + fn try_set_input(&self, input: &str, cx: &mut WindowContext) { if let Ok(input) = serde_json::from_str::(input) { self.update(cx, |view, cx| { view.set_input(input, cx); @@ -372,7 +353,6 @@ mod test { use super::*; use gpui::{div, prelude::*, Render, TestAppContext}; use gpui::{EmptyView, View}; - use schemars::schema_for; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -483,57 +463,64 @@ mod test { #[gpui::test] async fn test_openai_weather_example(cx: &mut TestAppContext) { - cx.background_executor.run_until_parked(); let (_, cx) = cx.add_window_view(|_cx| EmptyView); - let tool = WeatherTool { - current_weather: WeatherResult { - location: "San Francisco".to_string(), - temperature: 21.0, - unit: "Celsius".to_string(), - }, - }; - - let tools = vec![tool.definition()]; - assert_eq!(tools.len(), 1); - - let expected = ToolFunctionDefinition { - name: "get_current_weather".to_string(), - description: "Fetches the current weather for a given location.".to_string(), - parameters: schema_for!(WeatherQuery), - }; - - assert_eq!(tools[0].name, expected.name); - assert_eq!(tools[0].description, expected.description); - - let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap(); - - assert_eq!( - expected_schema, - json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "WeatherQuery", - "type": "object", - "properties": { - "location": { - "type": "string" - }, - "unit": { - "type": "string" - } + let mut registry = ToolRegistry::new(); + registry + .register(WeatherTool { + current_weather: WeatherResult { + location: "San Francisco".to_string(), + temperature: 21.0, + unit: "Celsius".to_string(), }, - "required": ["location", "unit"] }) + .unwrap(); + + let definitions = registry.definitions(); + assert_eq!( + definitions, + [ToolFunctionDefinition { + name: "get_current_weather".to_string(), + description: "Fetches the current weather for a given location.".to_string(), + parameters: serde_json::from_value(json!({ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "WeatherQuery", + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": "string" + } + }, + "required": ["location", "unit"] + })) + .unwrap(), + }] ); - let view = cx.update(|cx| tool.view(cx)); + let mut call = ToolFunctionCall { + id: "the-id".to_string(), + name: "get_cur".to_string(), + ..Default::default() + }; - cx.update(|cx| { - view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx); + let task = cx.update(|cx| { + registry.update_tool_call( + &mut call, + Some("rent_weather"), + Some(r#"{"location": "San Francisco","#), + cx, + ); + registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx); + registry.execute_tool_call(&mut call, cx).unwrap() }); + task.await.unwrap(); - let finished = cx.update(|cx| view.execute(cx)).await; - - assert!(finished.is_ok()); + match &call.state { + ToolFunctionCallState::ExecutedTool(_view) => {} + _ => panic!(), + } } }