Accept View
s on LanguageModelTool
s (#10956)
Creates a `ToolView` trait to allow interactivity. This brings expanding and collapsing to the excerpts from project index searches. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
7005f0b424
commit
f176e8f0e4
5 changed files with 268 additions and 269 deletions
|
@ -1,13 +1,16 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AnyElement, AppContext, Task, WindowContext};
|
||||
use std::{any::Any, collections::HashMap};
|
||||
use gpui::{Task, WindowContext};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tool::{
|
||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
};
|
||||
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
|
||||
tools: HashMap<
|
||||
String,
|
||||
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
>,
|
||||
definitions: Vec<ToolFunctionDefinition>,
|
||||
}
|
||||
|
||||
|
@ -24,77 +27,45 @@ impl ToolRegistry {
|
|||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||
fn render<T: 'static + LanguageModelTool>(
|
||||
tool_call_id: &str,
|
||||
input: &Box<dyn Any>,
|
||||
output: &Box<dyn Any>,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
T::render(
|
||||
tool_call_id,
|
||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn format<T: 'static + LanguageModelTool>(
|
||||
input: &Box<dyn Any>,
|
||||
output: &Box<dyn Any>,
|
||||
) -> String {
|
||||
T::format(
|
||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
self.definitions.push(tool.definition());
|
||||
let name = tool.name();
|
||||
let previous = self.tools.insert(
|
||||
name.clone(),
|
||||
Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
// registry.call(tool_call, cx)
|
||||
Box::new(
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
return Task::ready(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
});
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
cx.spawn(move |_cx| async move {
|
||||
match result.await {
|
||||
Ok(result) => {
|
||||
let result: T::Output = result;
|
||||
ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
input: Box::new(input),
|
||||
output: Box::new(result),
|
||||
render_fn: render::<T>,
|
||||
format_fn: format::<T>,
|
||||
}),
|
||||
}
|
||||
}
|
||||
Err(_error) => ToolFunctionCall {
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ExecutionFailed {
|
||||
input: Box::new(input),
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
}));
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<T::Output> = result.await;
|
||||
let for_model = T::format(&input, &result);
|
||||
let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
view: view.into(),
|
||||
for_model,
|
||||
}),
|
||||
},
|
||||
}
|
||||
})
|
||||
}),
|
||||
})
|
||||
})
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
if previous.is_some() {
|
||||
|
@ -104,7 +75,12 @@ impl ToolRegistry {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
|
||||
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
||||
pub fn call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<ToolFunctionCall>> {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
@ -113,12 +89,12 @@ impl ToolRegistry {
|
|||
Some(tool) => tool,
|
||||
None => {
|
||||
let name = name.clone();
|
||||
return Task::ready(ToolFunctionCall {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
});
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -128,12 +104,10 @@ impl ToolRegistry {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use super::*;
|
||||
|
||||
use gpui::View;
|
||||
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||
use schemars::schema_for;
|
||||
|
||||
use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
@ -155,9 +129,20 @@ mod test {
|
|||
unit: String,
|
||||
}
|
||||
|
||||
struct WeatherView {
|
||||
result: WeatherResult,
|
||||
}
|
||||
|
||||
impl Render for WeatherView {
|
||||
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
|
||||
div().child(format!("temperature: {}", self.result.temperature))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelTool for WeatherTool {
|
||||
type Input = WeatherQuery;
|
||||
type Output = WeatherResult;
|
||||
type View = WeatherView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"get_current_weather".to_string()
|
||||
|
@ -167,7 +152,11 @@ mod test {
|
|||
"Fetches the current weather for a given location.".to_string()
|
||||
}
|
||||
|
||||
fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
|
||||
fn execute(
|
||||
&self,
|
||||
input: &Self::Input,
|
||||
_cx: &gpui::AppContext,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
|
@ -176,25 +165,20 @@ mod test {
|
|||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn render(
|
||||
_tool_call_id: &str,
|
||||
_input: &Self::Input,
|
||||
output: &Self::Output,
|
||||
_cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
div()
|
||||
.child(format!(
|
||||
"The current temperature in {} is {} {}",
|
||||
output.location, output.temperature, output.unit
|
||||
))
|
||||
.into_any()
|
||||
fn new_view(
|
||||
_tool_call_id: String,
|
||||
_input: Self::Input,
|
||||
result: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
cx.new_view(|_cx| {
|
||||
let result = result.unwrap();
|
||||
WeatherView { result }
|
||||
})
|
||||
}
|
||||
|
||||
fn format(_input: &Self::Input, output: &Self::Output) -> String {
|
||||
format!(
|
||||
"The current temperature in {} is {} {}",
|
||||
output.location, output.temperature, output.unit
|
||||
)
|
||||
fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||
serde_json::to_string(&output.as_ref().unwrap()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -214,20 +198,20 @@ mod test {
|
|||
|
||||
registry.register(tool).unwrap();
|
||||
|
||||
let _result = cx
|
||||
.update(|cx| {
|
||||
registry.call(
|
||||
&ToolFunctionCall {
|
||||
name: "get_current_weather".to_string(),
|
||||
arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
|
||||
.to_string(),
|
||||
id: "test-123".to_string(),
|
||||
result: None,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
// let _result = cx
|
||||
// .update(|cx| {
|
||||
// registry.call(
|
||||
// &ToolFunctionCall {
|
||||
// name: "get_current_weather".to_string(),
|
||||
// arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
|
||||
// .to_string(),
|
||||
// id: "test-123".to_string(),
|
||||
// result: None,
|
||||
// },
|
||||
// cx,
|
||||
// )
|
||||
// })
|
||||
// .await;
|
||||
|
||||
// assert!(result.is_ok());
|
||||
// let result = result.unwrap();
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
use anyhow::Result;
|
||||
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
|
||||
use gpui::{AnyElement, AnyView, AppContext, IntoElement as _, Render, Task, View, WindowContext};
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
any::Any,
|
||||
fmt::{Debug, Display},
|
||||
};
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
pub struct ToolFunctionCall {
|
||||
|
@ -19,71 +16,29 @@ pub struct ToolFunctionCall {
|
|||
pub enum ToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
ExecutionFailed {
|
||||
input: Box<dyn Any>,
|
||||
},
|
||||
Finished {
|
||||
input: Box<dyn Any>,
|
||||
output: Box<dyn Any>,
|
||||
render_fn: fn(
|
||||
// tool_call_id
|
||||
&str,
|
||||
// LanguageModelTool::Input
|
||||
&Box<dyn Any>,
|
||||
// LanguageModelTool::Output
|
||||
&Box<dyn Any>,
|
||||
&mut WindowContext,
|
||||
) -> AnyElement,
|
||||
format_fn: fn(
|
||||
// LanguageModelTool::Input
|
||||
&Box<dyn Any>,
|
||||
// LanguageModelTool::Output
|
||||
&Box<dyn Any>,
|
||||
) -> String,
|
||||
},
|
||||
Finished { for_model: String, view: AnyView },
|
||||
}
|
||||
|
||||
impl ToolFunctionCallResult {
|
||||
pub fn render(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_call_id: &str,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
pub fn format(&self, name: &String) -> String {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => {
|
||||
div().child(format!("no such tool {tool_name}")).into_any()
|
||||
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Unable to parse arguments for {name}")
|
||||
}
|
||||
ToolFunctionCallResult::ParsingFailed => div()
|
||||
.child(format!("failed to parse input for tool {tool_name}"))
|
||||
.into_any(),
|
||||
ToolFunctionCallResult::ExecutionFailed { .. } => div()
|
||||
.child(format!("failed to execute tool {tool_name}"))
|
||||
.into_any(),
|
||||
ToolFunctionCallResult::Finished {
|
||||
input,
|
||||
output,
|
||||
render_fn,
|
||||
..
|
||||
} => render_fn(tool_call_id, input, output, cx),
|
||||
ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format(&self, tool: &str) -> String {
|
||||
pub fn into_any_element(&self, name: &String) -> AnyElement {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
|
||||
ToolFunctionCallResult::NoSuchTool => {
|
||||
format!("Language Model attempted to call {name}").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("failed to parse input for tool {tool}")
|
||||
format!("Language Model called {name} with bad arguments").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::ExecutionFailed { input: _input } => {
|
||||
format!("failed to execute tool {tool}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished {
|
||||
input,
|
||||
output,
|
||||
format_fn,
|
||||
..
|
||||
} => format_fn(input, output),
|
||||
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -105,19 +60,6 @@ impl Display for ToolFunctionDefinition {
|
|||
}
|
||||
}
|
||||
|
||||
impl Debug for ToolFunctionDefinition {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let schema = serde_json::to_string(&self.parameters).ok();
|
||||
let schema = schema.unwrap_or("None".to_string());
|
||||
|
||||
f.debug_struct("ToolFunctionDefinition")
|
||||
.field("name", &self.name)
|
||||
.field("description", &self.description)
|
||||
.field("parameters", &schema)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
|
@ -126,6 +68,8 @@ pub trait LanguageModelTool {
|
|||
/// The output returned by executing the tool.
|
||||
type Output: 'static;
|
||||
|
||||
type View: Render;
|
||||
|
||||
/// The name of the tool is exposed to the language model to allow
|
||||
/// the model to pick which tools to use. As this name is used to
|
||||
/// identify the tool within a tool registry, it should be unique.
|
||||
|
@ -149,12 +93,12 @@ pub trait LanguageModelTool {
|
|||
/// Execute the tool
|
||||
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn render(
|
||||
tool_call_id: &str,
|
||||
input: &Self::Input,
|
||||
output: &Self::Output,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement;
|
||||
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
|
||||
|
||||
fn format(input: &Self::Input, output: &Self::Output) -> String;
|
||||
fn new_view(
|
||||
tool_call_id: String,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View>;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue