New revision of the Assistant Panel (#10870)

This is a crate only addition of a new version of the AssistantPanel.
We'll be putting this behind a feature flag while we iron out the new
experience.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Conrad Irwin <conrad@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Co-authored-by: Antonio Scandurra <antonio@zed.dev>
Co-authored-by: Nate Butler <nate@zed.dev>
Co-authored-by: Nate Butler <iamnbutler@gmail.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-04-23 16:23:26 -07:00 committed by GitHub
parent e0c83a1d32
commit 68a1ad89bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 2989 additions and 262 deletions

View file

@ -0,0 +1,5 @@
pub mod registry;
pub mod tool;
pub use crate::registry::ToolRegistry;
pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};

View file

@ -0,0 +1,298 @@
use anyhow::{anyhow, Result};
use gpui::{AnyElement, AppContext, Task, WindowContext};
use std::{any::Any, collections::HashMap};
use crate::tool::{
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
};
pub struct ToolRegistry {
tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
definitions: Vec<ToolFunctionDefinition>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
definitions: Vec::new(),
}
}
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
&self.definitions
}
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
fn render<T: 'static + LanguageModelTool>(
tool_call_id: &str,
input: &Box<dyn Any>,
output: &Box<dyn Any>,
cx: &mut WindowContext,
) -> AnyElement {
T::render(
tool_call_id,
input.as_ref().downcast_ref::<T::Input>().unwrap(),
output.as_ref().downcast_ref::<T::Output>().unwrap(),
cx,
)
}
fn format<T: 'static + LanguageModelTool>(
input: &Box<dyn Any>,
output: &Box<dyn Any>,
) -> String {
T::format(
input.as_ref().downcast_ref::<T::Input>().unwrap(),
output.as_ref().downcast_ref::<T::Output>().unwrap(),
)
}
self.definitions.push(tool.definition());
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
return Task::ready(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
});
};
let result = tool.execute(&input, cx);
cx.spawn(move |_cx| async move {
match result.await {
Ok(result) => {
let result: T::Output = result;
ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
input: Box::new(input),
output: Box::new(result),
render_fn: render::<T>,
format_fn: format::<T>,
}),
}
}
Err(_error) => ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ExecutionFailed {
input: Box::new(input),
}),
},
}
})
}),
);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
Ok(())
}
pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let tool = match self.tools.get(&name) {
Some(tool) => tool,
None => {
let name = name.clone();
return Task::ready(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::NoSuchTool),
});
}
};
tool(tool_call, cx)
}
}
#[cfg(test)]
mod test {
use super::*;
use schemars::schema_for;
use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
location: String,
unit: String,
}
struct WeatherTool {
current_weather: WeatherResult,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
temperature: f64,
unit: String,
}
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
fn name(&self) -> String {
"get_current_weather".to_string()
}
fn description(&self) -> String {
"Fetches the current weather for a given location.".to_string()
}
fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
Task::ready(Ok(weather))
}
fn render(
_tool_call_id: &str,
_input: &Self::Input,
output: &Self::Output,
_cx: &mut WindowContext,
) -> AnyElement {
div()
.child(format!(
"The current temperature in {} is {} {}",
output.location, output.temperature, output.unit
))
.into_any()
}
fn format(_input: &Self::Input, output: &Self::Output) -> String {
format!(
"The current temperature in {} is {} {}",
output.location, output.temperature, output.unit
)
}
}
#[gpui::test]
async fn test_function_registry(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let mut registry = ToolRegistry::new();
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
registry.register(tool).unwrap();
let _result = cx
.update(|cx| {
registry.call(
&ToolFunctionCall {
name: "get_current_weather".to_string(),
arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
.to_string(),
id: "test-123".to_string(),
result: None,
},
cx,
)
})
.await;
// assert!(result.is_ok());
// let result = result.unwrap();
// let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
// todo!(): Put this back in after the interface is stabilized
// assert_eq!(result, expected);
}
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
let tool = WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
};
let tools = vec![tool.definition()];
assert_eq!(tools.len(), 1);
let expected = ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: schema_for!(WeatherQuery).schema,
};
assert_eq!(tools[0].name, expected.name);
assert_eq!(tools[0].description, expected.description);
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
assert_eq!(
expected_schema,
json!({
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
})
);
let args = json!({
"location": "San Francisco",
"unit": "Celsius"
});
let query: WeatherQuery = serde_json::from_value(args).unwrap();
let result = cx.update(|cx| tool.execute(&query, cx)).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result, tool.current_weather);
}
}

View file

@ -0,0 +1,145 @@
use anyhow::Result;
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
use schemars::{schema::SchemaObject, schema_for, JsonSchema};
use serde::Deserialize;
use std::{any::Any, fmt::Debug};
#[derive(Default, Deserialize)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
ExecutionFailed {
input: Box<dyn Any>,
},
Finished {
input: Box<dyn Any>,
output: Box<dyn Any>,
render_fn: fn(
// tool_call_id
&str,
// LanguageModelTool::Input
&Box<dyn Any>,
// LanguageModelTool::Output
&Box<dyn Any>,
&mut WindowContext,
) -> AnyElement,
format_fn: fn(
// LanguageModelTool::Input
&Box<dyn Any>,
// LanguageModelTool::Output
&Box<dyn Any>,
) -> String,
},
}
impl ToolFunctionCallResult {
pub fn render(
&self,
tool_name: &str,
tool_call_id: &str,
cx: &mut WindowContext,
) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
div().child(format!("no such tool {tool_name}")).into_any()
}
ToolFunctionCallResult::ParsingFailed => div()
.child(format!("failed to parse input for tool {tool_name}"))
.into_any(),
ToolFunctionCallResult::ExecutionFailed { .. } => div()
.child(format!("failed to execute tool {tool_name}"))
.into_any(),
ToolFunctionCallResult::Finished {
input,
output,
render_fn,
..
} => render_fn(tool_call_id, input, output, cx),
}
}
pub fn format(&self, tool: &str) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
ToolFunctionCallResult::ParsingFailed => {
format!("failed to parse input for tool {tool}")
}
ToolFunctionCallResult::ExecutionFailed { input: _input } => {
format!("failed to execute tool {tool}")
}
ToolFunctionCallResult::Finished {
input,
output,
format_fn,
..
} => format_fn(input, output),
}
}
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: SchemaObject,
}
impl Debug for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
f.debug_struct("ToolFunctionDefinition")
.field("name", &self.name)
.field("description", &self.description)
.field("parameters", &schema)
.finish()
}
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
/// The name of the tool is exposed to the language model to allow
/// the model to pick which tools to use. As this name is used to
/// identify the tool within a tool registry, it should be unique.
fn name(&self) -> String;
/// A description of the tool that can be used to _prompt_ the model
/// as to what the tool does.
fn description(&self) -> String;
/// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: schema_for!(Self::Input).schema,
}
}
/// Execute the tool
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
fn render(
tool_call_id: &str,
input: &Self::Input,
output: &Self::Output,
cx: &mut WindowContext,
) -> AnyElement;
fn format(input: &Self::Input, output: &Self::Output) -> String;
}