New revision of the Assistant Panel (#10870)
This is a crate only addition of a new version of the AssistantPanel. We'll be putting this behind a feature flag while we iron out the new experience. Release Notes: - N/A --------- Co-authored-by: Nathan Sobo <nathan@zed.dev> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Conrad Irwin <conrad@zed.dev> Co-authored-by: Marshall Bowers <elliott.codes@gmail.com> Co-authored-by: Antonio Scandurra <antonio@zed.dev> Co-authored-by: Nate Butler <nate@zed.dev> Co-authored-by: Nate Butler <iamnbutler@gmail.com> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
e0c83a1d32
commit
68a1ad89bb
55 changed files with 2989 additions and 262 deletions
22
crates/assistant_tooling/Cargo.toml
Normal file
22
crates/assistant_tooling/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "assistant_tooling"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant_tooling.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
gpui.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
1
crates/assistant_tooling/LICENSE-GPL
Symbolic link
1
crates/assistant_tooling/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
208
crates/assistant_tooling/README.md
Normal file
208
crates/assistant_tooling/README.md
Normal file
|
@ -0,0 +1,208 @@
|
|||
# Assistant Tooling
|
||||
|
||||
Bringing OpenAI compatible tool calling to GPUI.
|
||||
|
||||
This unlocks:
|
||||
|
||||
- **Structured Extraction** of model responses
|
||||
- **Validation** of model inputs
|
||||
- **Execution** of chosen toolsn
|
||||
|
||||
## Overview
|
||||
|
||||
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When make a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
|
||||
|
||||
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
|
||||
>
|
||||
> **Assistant**: "Sure, I can help with that. Let me see what I can find."
|
||||
>
|
||||
> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
|
||||
>
|
||||
> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
|
||||
>
|
||||
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
|
||||
|
||||
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`.
|
||||
|
||||
## Example
|
||||
|
||||
Let's expose querying a semantic index directly by the model. First, we'll set up some _necessary_ imports
|
||||
|
||||
```rust
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{LanguageModelTool, ToolRegistry};
|
||||
use gpui::{App, AppContext, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
```
|
||||
|
||||
Then we'll define the query structure the model must fill in. This _must_ derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate.
|
||||
|
||||
```rust
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct CodebaseQuery {
|
||||
query: String,
|
||||
}
|
||||
```
|
||||
|
||||
After that we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`.
|
||||
|
||||
```rust
|
||||
struct ProjectIndex {}
|
||||
|
||||
impl ProjectIndex {
|
||||
fn new() -> Self {
|
||||
ProjectIndex {}
|
||||
}
|
||||
|
||||
fn search(&self, _query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
|
||||
// Instead of hooking up a real index, we're going to fake it
|
||||
if _query.contains("gpui") {
|
||||
return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
|
||||
//! # Welcome to GPUI!
|
||||
//!
|
||||
//! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
|
||||
//! for Rust, designed to support a wide variety of applications
|
||||
"#
|
||||
.to_string()]));
|
||||
}
|
||||
return Task::ready(Ok(vec![]));
|
||||
}
|
||||
}
|
||||
|
||||
struct ProjectIndexTool {
|
||||
project_index: ProjectIndex,
|
||||
}
|
||||
```
|
||||
|
||||
Now we can implement the `LanguageModelTool` trait for our tool by:
|
||||
|
||||
- Defining the `Input` from the model, which is `CodebaseQuery`
|
||||
- Defining the `Output`
|
||||
- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool
|
||||
- Implementing the `execute` function to run the tool
|
||||
|
||||
```rust
|
||||
impl LanguageModelTool for ProjectIndexTool {
|
||||
type Input = CodebaseQuery;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"query_codebase".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Executes a query against the codebase, returning excerpts related to the query".to_string()
|
||||
}
|
||||
|
||||
fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
|
||||
let results = self.project_index.search(query.query.as_str(), 10, cx);
|
||||
|
||||
cx.spawn(|_cx| async move {
|
||||
let results = results.await?;
|
||||
|
||||
if !results.is_empty() {
|
||||
Ok(results.join("\n"))
|
||||
} else {
|
||||
Ok("No results".to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For the sake of this example, let's look at the types that OpenAI will be passing to us
|
||||
|
||||
```rust
|
||||
// OpenAI definitions, shown here for demonstration
|
||||
#[derive(Deserialize)]
|
||||
struct FunctionCall {
|
||||
name: String,
|
||||
args: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Eq, PartialEq)]
|
||||
enum ToolCallType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
|
||||
struct ToolCallId(String);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ToolCall {
|
||||
Function {
|
||||
#[allow(dead_code)]
|
||||
id: ToolCallId,
|
||||
function: FunctionCall,
|
||||
},
|
||||
Other {
|
||||
#[allow(dead_code)]
|
||||
id: ToolCallId,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AssistantMessage {
|
||||
role: String,
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
```
|
||||
|
||||
When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await.
|
||||
|
||||
```rust
|
||||
// Inside `fn main()`
|
||||
App::new().run(|cx: &mut AppContext| {
|
||||
let tool = ProjectIndexTool {
|
||||
project_index: ProjectIndex::new(),
|
||||
};
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
let registered = registry.register(tool);
|
||||
assert!(registered.is_ok());
|
||||
```
|
||||
|
||||
Let's pretend the model sent us back a message requesting
|
||||
|
||||
```rust
|
||||
let model_response = json!({
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "query_codebase",
|
||||
"args": r#"{"query":"GPUI Task background_executor"}"#
|
||||
},
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let message: AssistantMessage = serde_json::from_value(model_response).unwrap();
|
||||
|
||||
// We know there's a tool call, so let's skip straight to it for this example
|
||||
let tool_calls = message.tool_calls.as_ref().unwrap();
|
||||
let tool_call = tool_calls.get(0).unwrap();
|
||||
```
|
||||
|
||||
We can now use our registry to call the tool.
|
||||
|
||||
```rust
|
||||
let task = registry.call(
|
||||
tool_call.name,
|
||||
tool_call.args,
|
||||
);
|
||||
|
||||
cx.spawn(|_cx| async move {
|
||||
let result = task.await?;
|
||||
println!("{}", result.unwrap());
|
||||
Ok(())
|
||||
})
|
||||
```
|
5
crates/assistant_tooling/src/assistant_tooling.rs
Normal file
5
crates/assistant_tooling/src/assistant_tooling.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
pub mod registry;
|
||||
pub mod tool;
|
||||
|
||||
pub use crate::registry::ToolRegistry;
|
||||
pub use crate::tool::{LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition};
|
298
crates/assistant_tooling/src/registry.rs
Normal file
298
crates/assistant_tooling/src/registry.rs
Normal file
|
@ -0,0 +1,298 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AnyElement, AppContext, Task, WindowContext};
|
||||
use std::{any::Any, collections::HashMap};
|
||||
|
||||
use crate::tool::{
|
||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
};
|
||||
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
|
||||
definitions: Vec<ToolFunctionDefinition>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
definitions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> &[ToolFunctionDefinition] {
|
||||
&self.definitions
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||
fn render<T: 'static + LanguageModelTool>(
|
||||
tool_call_id: &str,
|
||||
input: &Box<dyn Any>,
|
||||
output: &Box<dyn Any>,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
T::render(
|
||||
tool_call_id,
|
||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn format<T: 'static + LanguageModelTool>(
|
||||
input: &Box<dyn Any>,
|
||||
output: &Box<dyn Any>,
|
||||
) -> String {
|
||||
T::format(
|
||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
self.definitions.push(tool.definition());
|
||||
let name = tool.name();
|
||||
let previous = self.tools.insert(
|
||||
name.clone(),
|
||||
Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
return Task::ready(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
});
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
cx.spawn(move |_cx| async move {
|
||||
match result.await {
|
||||
Ok(result) => {
|
||||
let result: T::Output = result;
|
||||
ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
input: Box::new(input),
|
||||
output: Box::new(result),
|
||||
render_fn: render::<T>,
|
||||
format_fn: format::<T>,
|
||||
}),
|
||||
}
|
||||
}
|
||||
Err(_error) => ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ExecutionFailed {
|
||||
input: Box::new(input),
|
||||
}),
|
||||
},
|
||||
}
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
if previous.is_some() {
|
||||
return Err(anyhow!("already registered a tool with name {}", name));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let tool = match self.tools.get(&name) {
|
||||
Some(tool) => tool,
|
||||
None => {
|
||||
let name = name.clone();
|
||||
return Task::ready(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
tool(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use super::*;
|
||||
|
||||
use schemars::schema_for;
|
||||
|
||||
use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema)]
|
||||
struct WeatherQuery {
|
||||
location: String,
|
||||
unit: String,
|
||||
}
|
||||
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
|
||||
struct WeatherResult {
|
||||
location: String,
|
||||
temperature: f64,
|
||||
unit: String,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for WeatherTool {
|
||||
type Input = WeatherQuery;
|
||||
type Output = WeatherResult;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"get_current_weather".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Fetches the current weather for a given location.".to_string()
|
||||
}
|
||||
|
||||
fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn render(
|
||||
_tool_call_id: &str,
|
||||
_input: &Self::Input,
|
||||
output: &Self::Output,
|
||||
_cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
div()
|
||||
.child(format!(
|
||||
"The current temperature in {} is {} {}",
|
||||
output.location, output.temperature, output.unit
|
||||
))
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn format(_input: &Self::Input, output: &Self::Output) -> String {
|
||||
format!(
|
||||
"The current temperature in {} is {} {}",
|
||||
output.location, output.temperature, output.unit
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_function_registry(cx: &mut TestAppContext) {
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
|
||||
let tool = WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
registry.register(tool).unwrap();
|
||||
|
||||
let _result = cx
|
||||
.update(|cx| {
|
||||
registry.call(
|
||||
&ToolFunctionCall {
|
||||
name: "get_current_weather".to_string(),
|
||||
arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
|
||||
.to_string(),
|
||||
id: "test-123".to_string(),
|
||||
result: None,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
// assert!(result.is_ok());
|
||||
// let result = result.unwrap();
|
||||
|
||||
// let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
|
||||
|
||||
// todo!(): Put this back in after the interface is stabilized
|
||||
// assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let tool = WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools = vec![tool.definition()];
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
let expected = ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: schema_for!(WeatherQuery).schema,
|
||||
};
|
||||
|
||||
assert_eq!(tools[0].name, expected.name);
|
||||
assert_eq!(tools[0].description, expected.description);
|
||||
|
||||
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
expected_schema,
|
||||
json!({
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
})
|
||||
);
|
||||
|
||||
let args = json!({
|
||||
"location": "San Francisco",
|
||||
"unit": "Celsius"
|
||||
});
|
||||
|
||||
let query: WeatherQuery = serde_json::from_value(args).unwrap();
|
||||
|
||||
let result = cx.update(|cx| tool.execute(&query, cx)).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
|
||||
assert_eq!(result, tool.current_weather);
|
||||
}
|
||||
}
|
145
crates/assistant_tooling/src/tool.rs
Normal file
145
crates/assistant_tooling/src/tool.rs
Normal file
|
@ -0,0 +1,145 @@
|
|||
use anyhow::Result;
|
||||
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
|
||||
use schemars::{schema::SchemaObject, schema_for, JsonSchema};
|
||||
use serde::Deserialize;
|
||||
use std::{any::Any, fmt::Debug};
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
#[serde(skip)]
|
||||
pub result: Option<ToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
pub enum ToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
ExecutionFailed {
|
||||
input: Box<dyn Any>,
|
||||
},
|
||||
Finished {
|
||||
input: Box<dyn Any>,
|
||||
output: Box<dyn Any>,
|
||||
render_fn: fn(
|
||||
// tool_call_id
|
||||
&str,
|
||||
// LanguageModelTool::Input
|
||||
&Box<dyn Any>,
|
||||
// LanguageModelTool::Output
|
||||
&Box<dyn Any>,
|
||||
&mut WindowContext,
|
||||
) -> AnyElement,
|
||||
format_fn: fn(
|
||||
// LanguageModelTool::Input
|
||||
&Box<dyn Any>,
|
||||
// LanguageModelTool::Output
|
||||
&Box<dyn Any>,
|
||||
) -> String,
|
||||
},
|
||||
}
|
||||
|
||||
impl ToolFunctionCallResult {
|
||||
pub fn render(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_call_id: &str,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => {
|
||||
div().child(format!("no such tool {tool_name}")).into_any()
|
||||
}
|
||||
ToolFunctionCallResult::ParsingFailed => div()
|
||||
.child(format!("failed to parse input for tool {tool_name}"))
|
||||
.into_any(),
|
||||
ToolFunctionCallResult::ExecutionFailed { .. } => div()
|
||||
.child(format!("failed to execute tool {tool_name}"))
|
||||
.into_any(),
|
||||
ToolFunctionCallResult::Finished {
|
||||
input,
|
||||
output,
|
||||
render_fn,
|
||||
..
|
||||
} => render_fn(tool_call_id, input, output, cx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format(&self, tool: &str) -> String {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("failed to parse input for tool {tool}")
|
||||
}
|
||||
ToolFunctionCallResult::ExecutionFailed { input: _input } => {
|
||||
format!("failed to execute tool {tool}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished {
|
||||
input,
|
||||
output,
|
||||
format_fn,
|
||||
..
|
||||
} => format_fn(input, output),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: SchemaObject,
|
||||
}
|
||||
|
||||
impl Debug for ToolFunctionDefinition {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let schema = serde_json::to_string(&self.parameters).ok();
|
||||
let schema = schema.unwrap_or("None".to_string());
|
||||
|
||||
f.debug_struct("ToolFunctionDefinition")
|
||||
.field("name", &self.name)
|
||||
.field("description", &self.description)
|
||||
.field("parameters", &schema)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type Output: 'static;
|
||||
|
||||
/// The name of the tool is exposed to the language model to allow
|
||||
/// the model to pick which tools to use. As this name is used to
|
||||
/// identify the tool within a tool registry, it should be unique.
|
||||
fn name(&self) -> String;
|
||||
|
||||
/// A description of the tool that can be used to _prompt_ the model
|
||||
/// as to what the tool does.
|
||||
fn description(&self) -> String;
|
||||
|
||||
/// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
||||
fn definition(&self) -> ToolFunctionDefinition {
|
||||
ToolFunctionDefinition {
|
||||
name: self.name(),
|
||||
description: self.description(),
|
||||
parameters: schema_for!(Self::Input).schema,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the tool
|
||||
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn render(
|
||||
tool_call_id: &str,
|
||||
input: &Self::Input,
|
||||
output: &Self::Output,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement;
|
||||
|
||||
fn format(input: &Self::Input, output: &Self::Output) -> String;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue