diff --git a/crates/assistant2/examples/chat-with-functions.rs b/crates/assistant2/examples/chat-with-functions.rs index 15d3c968a4..0a6ecbb02b 100644 --- a/crates/assistant2/examples/chat-with-functions.rs +++ b/crates/assistant2/examples/chat-with-functions.rs @@ -1,4 +1,5 @@ -use anyhow::Context as _; +/// This example creates a basic Chat UI with a function for rolling a die. +use anyhow::{Context as _, Result}; use assets::Assets; use assistant2::AssistantPanel; use assistant_tooling::{LanguageModelTool, ToolRegistry}; @@ -83,9 +84,32 @@ struct DiceRoll { rolls: Vec, } +pub struct DiceView { + result: Result, +} + +impl Render for DiceView { + fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { + let output = match &self.result { + Ok(output) => output, + Err(_) => return "Somehow dice failed 🎲".into_any_element(), + }; + + h_flex() + .children( + output + .rolls + .iter() + .map(|roll| div().p_2().child(roll.render())), + ) + .into_any_element() + } +} + impl LanguageModelTool for RollDiceTool { type Input = DiceParams; type Output = DiceRoll; + type View = DiceView; fn name(&self) -> String { "roll_dice".to_string() @@ -110,23 +134,21 @@ impl LanguageModelTool for RollDiceTool { return Task::ready(Ok(DiceRoll { rolls })); } - fn render( - _tool_call_id: &str, - _input: &Self::Input, - output: &Self::Output, - _cx: &mut WindowContext, - ) -> gpui::AnyElement { - h_flex() - .children( - output - .rolls - .iter() - .map(|roll| div().p_2().child(roll.render())), - ) - .into_any_element() + fn new_view( + _tool_call_id: String, + _input: Self::Input, + result: Result, + cx: &mut WindowContext, + ) -> gpui::View { + cx.new_view(|_cx| DiceView { result }) } - fn format(_input: &Self::Input, output: &Self::Output) -> String { + fn format(_: &Self::Input, output: &Result) -> String { + let output = match output { + Ok(output) => output, + Err(_) => return "Somehow dice failed 🎲".to_string(), + }; + let mut result = String::new(); for roll in &output.rolls { let die = &roll.die; diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index b89291bd13..22d05a3fc5 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -322,9 +322,11 @@ impl AssistantChat { }; call_count += 1; + let messages = this.completion_messages(cx); + CompletionProvider::get(cx).complete( this.model.clone(), - this.completion_messages(cx), + messages, Vec::new(), 1.0, definitions, @@ -407,6 +409,10 @@ impl AssistantChat { } let tools = join_all(tool_tasks.into_iter()).await; + // If the WindowContext went away for any tool's view we don't include it + // especially since the below call would fail for the same reason. + let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect(); + this.update(cx, |this, cx| { if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) = this.messages.last_mut() @@ -561,10 +567,9 @@ impl AssistantChat { let result = &tool_call.result; let name = tool_call.name.clone(); match result { - Some(result) => div() - .p_2() - .child(result.render(&name, &tool_call.id, cx)) - .into_any(), + Some(result) => { + div().p_2().child(result.into_any_element(&name)).into_any() + } None => div() .p_2() .child(Label::new(name).color(Color::Modified)) @@ -577,7 +582,7 @@ impl AssistantChat { } } - fn completion_messages(&self, cx: &WindowContext) -> Vec { + fn completion_messages(&self, cx: &mut WindowContext) -> Vec { let mut completion_messages = Vec::new(); for message in &self.messages { diff --git a/crates/assistant2/src/tools.rs b/crates/assistant2/src/tools.rs index ffd5e42bfa..3e86e72168 100644 --- a/crates/assistant2/src/tools.rs +++ b/crates/assistant2/src/tools.rs @@ -1,10 +1,10 @@ use anyhow::Result; use assistant_tooling::LanguageModelTool; -use gpui::{prelude::*, AnyElement, AppContext, Model, Task}; +use gpui::{prelude::*, AppContext, Model, Task}; use project::Fs; use schemars::JsonSchema; use semantic_index::ProjectIndex; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use std::sync::Arc; use ui::{ div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString, @@ -14,11 +14,13 @@ use util::ResultExt as _; const DEFAULT_SEARCH_LIMIT: usize = 20; -#[derive(Serialize, Clone)] +#[derive(Clone)] pub struct CodebaseExcerpt { path: SharedString, text: SharedString, score: f32, + element_id: ElementId, + expanded: bool, } // Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model. @@ -32,6 +34,79 @@ pub struct CodebaseQuery { limit: Option, } +pub struct ProjectIndexView { + input: CodebaseQuery, + output: Result>, +} + +impl ProjectIndexView { + fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext) { + if let Ok(excerpts) = &mut self.output { + if let Some(excerpt) = excerpts + .iter_mut() + .find(|excerpt| excerpt.element_id == element_id) + { + excerpt.expanded = !excerpt.expanded; + cx.notify(); + } + } + } +} + +impl Render for ProjectIndexView { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let query = self.input.query.clone(); + + let result = &self.output; + + let excerpts = match result { + Err(err) => { + return div().child(Label::new(format!("Error: {}", err)).color(Color::Error)); + } + Ok(excerpts) => excerpts, + }; + + div() + .v_flex() + .gap_2() + .child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().editor_background) + .child( + h_flex() + .child(Label::new("Query: ").color(Color::Modified)) + .child(Label::new(query).color(Color::Muted)), + ), + ) + .children(excerpts.iter().map(|excerpt| { + let element_id = excerpt.element_id.clone(); + let expanded = excerpt.expanded; + + CollapsibleContainer::new(element_id.clone(), expanded) + .start_slot( + h_flex() + .gap_1() + .child(Icon::new(IconName::File).color(Color::Muted)) + .child(Label::new(excerpt.path.clone()).color(Color::Muted)), + ) + .on_click(cx.listener(move |this, _, cx| { + this.toggle_expanded(element_id.clone(), cx); + })) + .child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().editor_background) + .child( + excerpt.text.clone(), // todo!(): Show as an editor block + ), + ) + })) + } +} + pub struct ProjectIndexTool { project_index: Model, fs: Arc, @@ -47,6 +122,7 @@ impl ProjectIndexTool { impl LanguageModelTool for ProjectIndexTool { type Input = CodebaseQuery; type Output = Vec; + type View = ProjectIndexView; fn name(&self) -> String { "query_codebase".to_string() @@ -90,6 +166,8 @@ impl LanguageModelTool for ProjectIndexTool { } anyhow::Ok(CodebaseExcerpt { + element_id: ElementId::Name(nanoid::nanoid!().into()), + expanded: false, path: path.to_string_lossy().to_string().into(), text: SharedString::from(text[start..end].to_string()), score: result.score, @@ -106,71 +184,37 @@ impl LanguageModelTool for ProjectIndexTool { }) } - fn render( - _tool_call_id: &str, - input: &Self::Input, - excerpts: &Self::Output, + fn new_view( + _tool_call_id: String, + input: Self::Input, + output: Result, cx: &mut WindowContext, - ) -> AnyElement { - let query = input.query.clone(); - - div() - .v_flex() - .gap_2() - .child( - div() - .p_2() - .rounded_md() - .bg(cx.theme().colors().editor_background) - .child( - h_flex() - .child(Label::new("Query: ").color(Color::Modified)) - .child(Label::new(query).color(Color::Muted)), - ), - ) - .children(excerpts.iter().map(|excerpt| { - // This render doesn't have state/model, so we can't use the listener - // let expanded = excerpt.expanded; - // let element_id = excerpt.element_id.clone(); - let element_id = ElementId::Name(nanoid::nanoid!().into()); - let expanded = false; - - CollapsibleContainer::new(element_id.clone(), expanded) - .start_slot( - h_flex() - .gap_1() - .child(Icon::new(IconName::File).color(Color::Muted)) - .child(Label::new(excerpt.path.clone()).color(Color::Muted)), - ) - // .on_click(cx.listener(move |this, _, cx| { - // this.toggle_expanded(element_id.clone(), cx); - // })) - .child( - div() - .p_2() - .rounded_md() - .bg(cx.theme().colors().editor_background) - .child( - excerpt.text.clone(), // todo!(): Show as an editor block - ), - ) - })) - .into_any_element() + ) -> gpui::View { + cx.new_view(|_cx| ProjectIndexView { input, output }) } - fn format(_input: &Self::Input, excerpts: &Self::Output) -> String { - let mut body = "Semantic search results:\n".to_string(); + fn format(_input: &Self::Input, output: &Result) -> String { + match &output { + Ok(excerpts) => { + if excerpts.len() == 0 { + return "No results found".to_string(); + } - for excerpt in excerpts { - body.push_str("Excerpt from "); - body.push_str(excerpt.path.as_ref()); - body.push_str(", score "); - body.push_str(&excerpt.score.to_string()); - body.push_str(":\n"); - body.push_str("~~~\n"); - body.push_str(excerpt.text.as_ref()); - body.push_str("~~~\n"); + let mut body = "Semantic search results:\n".to_string(); + + for excerpt in excerpts { + body.push_str("Excerpt from "); + body.push_str(excerpt.path.as_ref()); + body.push_str(", score "); + body.push_str(&excerpt.score.to_string()); + body.push_str(":\n"); + body.push_str("~~~\n"); + body.push_str(excerpt.text.as_ref()); + body.push_str("~~~\n"); + } + body + } + Err(err) => format!("Error: {}", err), } - body } } diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/registry.rs index ac5930cac4..6a3bc313cd 100644 --- a/crates/assistant_tooling/src/registry.rs +++ b/crates/assistant_tooling/src/registry.rs @@ -1,13 +1,16 @@ use anyhow::{anyhow, Result}; -use gpui::{AnyElement, AppContext, Task, WindowContext}; -use std::{any::Any, collections::HashMap}; +use gpui::{Task, WindowContext}; +use std::collections::HashMap; use crate::tool::{ LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition, }; pub struct ToolRegistry { - tools: HashMap Task>>, + tools: HashMap< + String, + Box Task>>, + >, definitions: Vec, } @@ -24,77 +27,45 @@ impl ToolRegistry { } pub fn register(&mut self, tool: T) -> Result<()> { - fn render( - tool_call_id: &str, - input: &Box, - output: &Box, - cx: &mut WindowContext, - ) -> AnyElement { - T::render( - tool_call_id, - input.as_ref().downcast_ref::().unwrap(), - output.as_ref().downcast_ref::().unwrap(), - cx, - ) - } - - fn format( - input: &Box, - output: &Box, - ) -> String { - T::format( - input.as_ref().downcast_ref::().unwrap(), - output.as_ref().downcast_ref::().unwrap(), - ) - } - self.definitions.push(tool.definition()); let name = tool.name(); let previous = self.tools.insert( name.clone(), - Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| { - let name = tool_call.name.clone(); - let arguments = tool_call.arguments.clone(); - let id = tool_call.id.clone(); + // registry.call(tool_call, cx) + Box::new( + move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| { + let name = tool_call.name.clone(); + let arguments = tool_call.arguments.clone(); + let id = tool_call.id.clone(); - let Ok(input) = serde_json::from_str::(arguments.as_str()) else { - return Task::ready(ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::ParsingFailed), - }); - }; - - let result = tool.execute(&input, cx); - - cx.spawn(move |_cx| async move { - match result.await { - Ok(result) => { - let result: T::Output = result; - ToolFunctionCall { - id, - name: name.clone(), - arguments, - result: Some(ToolFunctionCallResult::Finished { - input: Box::new(input), - output: Box::new(result), - render_fn: render::, - format_fn: format::, - }), - } - } - Err(_error) => ToolFunctionCall { + let Ok(input) = serde_json::from_str::(arguments.as_str()) else { + return Task::ready(Ok(ToolFunctionCall { id, name: name.clone(), arguments, - result: Some(ToolFunctionCallResult::ExecutionFailed { - input: Box::new(input), + result: Some(ToolFunctionCallResult::ParsingFailed), + })); + }; + + let result = tool.execute(&input, cx); + + cx.spawn(move |mut cx| async move { + let result: Result = result.await; + let for_model = T::format(&input, &result); + let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?; + + Ok(ToolFunctionCall { + id, + name: name.clone(), + arguments, + result: Some(ToolFunctionCallResult::Finished { + view: view.into(), + for_model, }), - }, - } - }) - }), + }) + }) + }, + ), ); if previous.is_some() { @@ -104,7 +75,12 @@ impl ToolRegistry { Ok(()) } - pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task { + /// Task yields an error if the window for the given WindowContext is closed before the task completes. + pub fn call( + &self, + tool_call: &ToolFunctionCall, + cx: &mut WindowContext, + ) -> Task> { let name = tool_call.name.clone(); let arguments = tool_call.arguments.clone(); let id = tool_call.id.clone(); @@ -113,12 +89,12 @@ impl ToolRegistry { Some(tool) => tool, None => { let name = name.clone(); - return Task::ready(ToolFunctionCall { + return Task::ready(Ok(ToolFunctionCall { id, name: name.clone(), arguments, result: Some(ToolFunctionCallResult::NoSuchTool), - }); + })); } }; @@ -128,12 +104,10 @@ impl ToolRegistry { #[cfg(test)] mod test { - use super::*; - + use gpui::View; + use gpui::{div, prelude::*, Render, TestAppContext}; use schemars::schema_for; - - use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -155,9 +129,20 @@ mod test { unit: String, } + struct WeatherView { + result: WeatherResult, + } + + impl Render for WeatherView { + fn render(&mut self, _cx: &mut gpui::ViewContext) -> impl IntoElement { + div().child(format!("temperature: {}", self.result.temperature)) + } + } + impl LanguageModelTool for WeatherTool { type Input = WeatherQuery; type Output = WeatherResult; + type View = WeatherView; fn name(&self) -> String { "get_current_weather".to_string() @@ -167,7 +152,11 @@ mod test { "Fetches the current weather for a given location.".to_string() } - fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task> { + fn execute( + &self, + input: &Self::Input, + _cx: &gpui::AppContext, + ) -> Task> { let _location = input.location.clone(); let _unit = input.unit.clone(); @@ -176,25 +165,20 @@ mod test { Task::ready(Ok(weather)) } - fn render( - _tool_call_id: &str, - _input: &Self::Input, - output: &Self::Output, - _cx: &mut WindowContext, - ) -> AnyElement { - div() - .child(format!( - "The current temperature in {} is {} {}", - output.location, output.temperature, output.unit - )) - .into_any() + fn new_view( + _tool_call_id: String, + _input: Self::Input, + result: Result, + cx: &mut WindowContext, + ) -> View { + cx.new_view(|_cx| { + let result = result.unwrap(); + WeatherView { result } + }) } - fn format(_input: &Self::Input, output: &Self::Output) -> String { - format!( - "The current temperature in {} is {} {}", - output.location, output.temperature, output.unit - ) + fn format(_: &Self::Input, output: &Result) -> String { + serde_json::to_string(&output.as_ref().unwrap()).unwrap() } } @@ -214,20 +198,20 @@ mod test { registry.register(tool).unwrap(); - let _result = cx - .update(|cx| { - registry.call( - &ToolFunctionCall { - name: "get_current_weather".to_string(), - arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"# - .to_string(), - id: "test-123".to_string(), - result: None, - }, - cx, - ) - }) - .await; + // let _result = cx + // .update(|cx| { + // registry.call( + // &ToolFunctionCall { + // name: "get_current_weather".to_string(), + // arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"# + // .to_string(), + // id: "test-123".to_string(), + // result: None, + // }, + // cx, + // ) + // }) + // .await; // assert!(result.is_ok()); // let result = result.unwrap(); diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs index a3b021a04e..8a1ffcf9d4 100644 --- a/crates/assistant_tooling/src/tool.rs +++ b/crates/assistant_tooling/src/tool.rs @@ -1,11 +1,8 @@ use anyhow::Result; -use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext}; +use gpui::{AnyElement, AnyView, AppContext, IntoElement as _, Render, Task, View, WindowContext}; use schemars::{schema::RootSchema, schema_for, JsonSchema}; use serde::Deserialize; -use std::{ - any::Any, - fmt::{Debug, Display}, -}; +use std::fmt::Display; #[derive(Default, Deserialize)] pub struct ToolFunctionCall { @@ -19,71 +16,29 @@ pub struct ToolFunctionCall { pub enum ToolFunctionCallResult { NoSuchTool, ParsingFailed, - ExecutionFailed { - input: Box, - }, - Finished { - input: Box, - output: Box, - render_fn: fn( - // tool_call_id - &str, - // LanguageModelTool::Input - &Box, - // LanguageModelTool::Output - &Box, - &mut WindowContext, - ) -> AnyElement, - format_fn: fn( - // LanguageModelTool::Input - &Box, - // LanguageModelTool::Output - &Box, - ) -> String, - }, + Finished { for_model: String, view: AnyView }, } impl ToolFunctionCallResult { - pub fn render( - &self, - tool_name: &str, - tool_call_id: &str, - cx: &mut WindowContext, - ) -> AnyElement { + pub fn format(&self, name: &String) -> String { match self { - ToolFunctionCallResult::NoSuchTool => { - div().child(format!("no such tool {tool_name}")).into_any() + ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"), + ToolFunctionCallResult::ParsingFailed => { + format!("Unable to parse arguments for {name}") } - ToolFunctionCallResult::ParsingFailed => div() - .child(format!("failed to parse input for tool {tool_name}")) - .into_any(), - ToolFunctionCallResult::ExecutionFailed { .. } => div() - .child(format!("failed to execute tool {tool_name}")) - .into_any(), - ToolFunctionCallResult::Finished { - input, - output, - render_fn, - .. - } => render_fn(tool_call_id, input, output, cx), + ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(), } } - pub fn format(&self, tool: &str) -> String { + pub fn into_any_element(&self, name: &String) -> AnyElement { match self { - ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"), + ToolFunctionCallResult::NoSuchTool => { + format!("Language Model attempted to call {name}").into_any_element() + } ToolFunctionCallResult::ParsingFailed => { - format!("failed to parse input for tool {tool}") + format!("Language Model called {name} with bad arguments").into_any_element() } - ToolFunctionCallResult::ExecutionFailed { input: _input } => { - format!("failed to execute tool {tool}") - } - ToolFunctionCallResult::Finished { - input, - output, - format_fn, - .. - } => format_fn(input, output), + ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(), } } } @@ -105,19 +60,6 @@ impl Display for ToolFunctionDefinition { } } -impl Debug for ToolFunctionDefinition { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let schema = serde_json::to_string(&self.parameters).ok(); - let schema = schema.unwrap_or("None".to_string()); - - f.debug_struct("ToolFunctionDefinition") - .field("name", &self.name) - .field("description", &self.description) - .field("parameters", &schema) - .finish() - } -} - pub trait LanguageModelTool { /// The input type that will be passed in to `execute` when the tool is called /// by the language model. @@ -126,6 +68,8 @@ pub trait LanguageModelTool { /// The output returned by executing the tool. type Output: 'static; + type View: Render; + /// The name of the tool is exposed to the language model to allow /// the model to pick which tools to use. As this name is used to /// identify the tool within a tool registry, it should be unique. @@ -149,12 +93,12 @@ pub trait LanguageModelTool { /// Execute the tool fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task>; - fn render( - tool_call_id: &str, - input: &Self::Input, - output: &Self::Output, - cx: &mut WindowContext, - ) -> AnyElement; + fn format(input: &Self::Input, output: &Result) -> String; - fn format(input: &Self::Input, output: &Self::Output) -> String; + fn new_view( + tool_call_id: String, + input: Self::Input, + output: Result, + cx: &mut WindowContext, + ) -> View; }