Accept View
s on LanguageModelTool
s (#10956)
Creates a `ToolView` trait to allow interactivity. This brings expanding and collapsing to the excerpts from project index searches. Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
7005f0b424
commit
f176e8f0e4
5 changed files with 268 additions and 269 deletions
|
@ -1,4 +1,5 @@
|
||||||
use anyhow::Context as _;
|
/// This example creates a basic Chat UI with a function for rolling a die.
|
||||||
|
use anyhow::{Context as _, Result};
|
||||||
use assets::Assets;
|
use assets::Assets;
|
||||||
use assistant2::AssistantPanel;
|
use assistant2::AssistantPanel;
|
||||||
use assistant_tooling::{LanguageModelTool, ToolRegistry};
|
use assistant_tooling::{LanguageModelTool, ToolRegistry};
|
||||||
|
@ -83,9 +84,32 @@ struct DiceRoll {
|
||||||
rolls: Vec<DieRoll>,
|
rolls: Vec<DieRoll>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct DiceView {
|
||||||
|
result: Result<DiceRoll>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for DiceView {
|
||||||
|
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
let output = match &self.result {
|
||||||
|
Ok(output) => output,
|
||||||
|
Err(_) => return "Somehow dice failed 🎲".into_any_element(),
|
||||||
|
};
|
||||||
|
|
||||||
|
h_flex()
|
||||||
|
.children(
|
||||||
|
output
|
||||||
|
.rolls
|
||||||
|
.iter()
|
||||||
|
.map(|roll| div().p_2().child(roll.render())),
|
||||||
|
)
|
||||||
|
.into_any_element()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LanguageModelTool for RollDiceTool {
|
impl LanguageModelTool for RollDiceTool {
|
||||||
type Input = DiceParams;
|
type Input = DiceParams;
|
||||||
type Output = DiceRoll;
|
type Output = DiceRoll;
|
||||||
|
type View = DiceView;
|
||||||
|
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"roll_dice".to_string()
|
"roll_dice".to_string()
|
||||||
|
@ -110,23 +134,21 @@ impl LanguageModelTool for RollDiceTool {
|
||||||
return Task::ready(Ok(DiceRoll { rolls }));
|
return Task::ready(Ok(DiceRoll { rolls }));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render(
|
fn new_view(
|
||||||
_tool_call_id: &str,
|
_tool_call_id: String,
|
||||||
_input: &Self::Input,
|
_input: Self::Input,
|
||||||
output: &Self::Output,
|
result: Result<Self::Output>,
|
||||||
_cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> gpui::AnyElement {
|
) -> gpui::View<Self::View> {
|
||||||
h_flex()
|
cx.new_view(|_cx| DiceView { result })
|
||||||
.children(
|
|
||||||
output
|
|
||||||
.rolls
|
|
||||||
.iter()
|
|
||||||
.map(|roll| div().p_2().child(roll.render())),
|
|
||||||
)
|
|
||||||
.into_any_element()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(_input: &Self::Input, output: &Self::Output) -> String {
|
fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||||
|
let output = match output {
|
||||||
|
Ok(output) => output,
|
||||||
|
Err(_) => return "Somehow dice failed 🎲".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
for roll in &output.rolls {
|
for roll in &output.rolls {
|
||||||
let die = &roll.die;
|
let die = &roll.die;
|
||||||
|
|
|
@ -322,9 +322,11 @@ impl AssistantChat {
|
||||||
};
|
};
|
||||||
call_count += 1;
|
call_count += 1;
|
||||||
|
|
||||||
|
let messages = this.completion_messages(cx);
|
||||||
|
|
||||||
CompletionProvider::get(cx).complete(
|
CompletionProvider::get(cx).complete(
|
||||||
this.model.clone(),
|
this.model.clone(),
|
||||||
this.completion_messages(cx),
|
messages,
|
||||||
Vec::new(),
|
Vec::new(),
|
||||||
1.0,
|
1.0,
|
||||||
definitions,
|
definitions,
|
||||||
|
@ -407,6 +409,10 @@ impl AssistantChat {
|
||||||
}
|
}
|
||||||
|
|
||||||
let tools = join_all(tool_tasks.into_iter()).await;
|
let tools = join_all(tool_tasks.into_iter()).await;
|
||||||
|
// If the WindowContext went away for any tool's view we don't include it
|
||||||
|
// especially since the below call would fail for the same reason.
|
||||||
|
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
|
if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
|
||||||
this.messages.last_mut()
|
this.messages.last_mut()
|
||||||
|
@ -561,10 +567,9 @@ impl AssistantChat {
|
||||||
let result = &tool_call.result;
|
let result = &tool_call.result;
|
||||||
let name = tool_call.name.clone();
|
let name = tool_call.name.clone();
|
||||||
match result {
|
match result {
|
||||||
Some(result) => div()
|
Some(result) => {
|
||||||
.p_2()
|
div().p_2().child(result.into_any_element(&name)).into_any()
|
||||||
.child(result.render(&name, &tool_call.id, cx))
|
}
|
||||||
.into_any(),
|
|
||||||
None => div()
|
None => div()
|
||||||
.p_2()
|
.p_2()
|
||||||
.child(Label::new(name).color(Color::Modified))
|
.child(Label::new(name).color(Color::Modified))
|
||||||
|
@ -577,7 +582,7 @@ impl AssistantChat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
|
fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
|
||||||
let mut completion_messages = Vec::new();
|
let mut completion_messages = Vec::new();
|
||||||
|
|
||||||
for message in &self.messages {
|
for message in &self.messages {
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use assistant_tooling::LanguageModelTool;
|
use assistant_tooling::LanguageModelTool;
|
||||||
use gpui::{prelude::*, AnyElement, AppContext, Model, Task};
|
use gpui::{prelude::*, AppContext, Model, Task};
|
||||||
use project::Fs;
|
use project::Fs;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use semantic_index::ProjectIndex;
|
use semantic_index::ProjectIndex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::Deserialize;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use ui::{
|
use ui::{
|
||||||
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
|
div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, SharedString,
|
||||||
|
@ -14,11 +14,13 @@ use util::ResultExt as _;
|
||||||
|
|
||||||
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CodebaseExcerpt {
|
pub struct CodebaseExcerpt {
|
||||||
path: SharedString,
|
path: SharedString,
|
||||||
text: SharedString,
|
text: SharedString,
|
||||||
score: f32,
|
score: f32,
|
||||||
|
element_id: ElementId,
|
||||||
|
expanded: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
|
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
|
||||||
|
@ -32,6 +34,79 @@ pub struct CodebaseQuery {
|
||||||
limit: Option<usize>,
|
limit: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ProjectIndexView {
|
||||||
|
input: CodebaseQuery,
|
||||||
|
output: Result<Vec<CodebaseExcerpt>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProjectIndexView {
|
||||||
|
fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
|
||||||
|
if let Ok(excerpts) = &mut self.output {
|
||||||
|
if let Some(excerpt) = excerpts
|
||||||
|
.iter_mut()
|
||||||
|
.find(|excerpt| excerpt.element_id == element_id)
|
||||||
|
{
|
||||||
|
excerpt.expanded = !excerpt.expanded;
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for ProjectIndexView {
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||||
|
let query = self.input.query.clone();
|
||||||
|
|
||||||
|
let result = &self.output;
|
||||||
|
|
||||||
|
let excerpts = match result {
|
||||||
|
Err(err) => {
|
||||||
|
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
|
||||||
|
}
|
||||||
|
Ok(excerpts) => excerpts,
|
||||||
|
};
|
||||||
|
|
||||||
|
div()
|
||||||
|
.v_flex()
|
||||||
|
.gap_2()
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.p_2()
|
||||||
|
.rounded_md()
|
||||||
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
.child(
|
||||||
|
h_flex()
|
||||||
|
.child(Label::new("Query: ").color(Color::Modified))
|
||||||
|
.child(Label::new(query).color(Color::Muted)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.children(excerpts.iter().map(|excerpt| {
|
||||||
|
let element_id = excerpt.element_id.clone();
|
||||||
|
let expanded = excerpt.expanded;
|
||||||
|
|
||||||
|
CollapsibleContainer::new(element_id.clone(), expanded)
|
||||||
|
.start_slot(
|
||||||
|
h_flex()
|
||||||
|
.gap_1()
|
||||||
|
.child(Icon::new(IconName::File).color(Color::Muted))
|
||||||
|
.child(Label::new(excerpt.path.clone()).color(Color::Muted)),
|
||||||
|
)
|
||||||
|
.on_click(cx.listener(move |this, _, cx| {
|
||||||
|
this.toggle_expanded(element_id.clone(), cx);
|
||||||
|
}))
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.p_2()
|
||||||
|
.rounded_md()
|
||||||
|
.bg(cx.theme().colors().editor_background)
|
||||||
|
.child(
|
||||||
|
excerpt.text.clone(), // todo!(): Show as an editor block
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ProjectIndexTool {
|
pub struct ProjectIndexTool {
|
||||||
project_index: Model<ProjectIndex>,
|
project_index: Model<ProjectIndex>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
|
@ -47,6 +122,7 @@ impl ProjectIndexTool {
|
||||||
impl LanguageModelTool for ProjectIndexTool {
|
impl LanguageModelTool for ProjectIndexTool {
|
||||||
type Input = CodebaseQuery;
|
type Input = CodebaseQuery;
|
||||||
type Output = Vec<CodebaseExcerpt>;
|
type Output = Vec<CodebaseExcerpt>;
|
||||||
|
type View = ProjectIndexView;
|
||||||
|
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"query_codebase".to_string()
|
"query_codebase".to_string()
|
||||||
|
@ -90,6 +166,8 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::Ok(CodebaseExcerpt {
|
anyhow::Ok(CodebaseExcerpt {
|
||||||
|
element_id: ElementId::Name(nanoid::nanoid!().into()),
|
||||||
|
expanded: false,
|
||||||
path: path.to_string_lossy().to_string().into(),
|
path: path.to_string_lossy().to_string().into(),
|
||||||
text: SharedString::from(text[start..end].to_string()),
|
text: SharedString::from(text[start..end].to_string()),
|
||||||
score: result.score,
|
score: result.score,
|
||||||
|
@ -106,71 +184,37 @@ impl LanguageModelTool for ProjectIndexTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render(
|
fn new_view(
|
||||||
_tool_call_id: &str,
|
_tool_call_id: String,
|
||||||
input: &Self::Input,
|
input: Self::Input,
|
||||||
excerpts: &Self::Output,
|
output: Result<Self::Output>,
|
||||||
cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> AnyElement {
|
) -> gpui::View<Self::View> {
|
||||||
let query = input.query.clone();
|
cx.new_view(|_cx| ProjectIndexView { input, output })
|
||||||
|
|
||||||
div()
|
|
||||||
.v_flex()
|
|
||||||
.gap_2()
|
|
||||||
.child(
|
|
||||||
div()
|
|
||||||
.p_2()
|
|
||||||
.rounded_md()
|
|
||||||
.bg(cx.theme().colors().editor_background)
|
|
||||||
.child(
|
|
||||||
h_flex()
|
|
||||||
.child(Label::new("Query: ").color(Color::Modified))
|
|
||||||
.child(Label::new(query).color(Color::Muted)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.children(excerpts.iter().map(|excerpt| {
|
|
||||||
// This render doesn't have state/model, so we can't use the listener
|
|
||||||
// let expanded = excerpt.expanded;
|
|
||||||
// let element_id = excerpt.element_id.clone();
|
|
||||||
let element_id = ElementId::Name(nanoid::nanoid!().into());
|
|
||||||
let expanded = false;
|
|
||||||
|
|
||||||
CollapsibleContainer::new(element_id.clone(), expanded)
|
|
||||||
.start_slot(
|
|
||||||
h_flex()
|
|
||||||
.gap_1()
|
|
||||||
.child(Icon::new(IconName::File).color(Color::Muted))
|
|
||||||
.child(Label::new(excerpt.path.clone()).color(Color::Muted)),
|
|
||||||
)
|
|
||||||
// .on_click(cx.listener(move |this, _, cx| {
|
|
||||||
// this.toggle_expanded(element_id.clone(), cx);
|
|
||||||
// }))
|
|
||||||
.child(
|
|
||||||
div()
|
|
||||||
.p_2()
|
|
||||||
.rounded_md()
|
|
||||||
.bg(cx.theme().colors().editor_background)
|
|
||||||
.child(
|
|
||||||
excerpt.text.clone(), // todo!(): Show as an editor block
|
|
||||||
),
|
|
||||||
)
|
|
||||||
}))
|
|
||||||
.into_any_element()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(_input: &Self::Input, excerpts: &Self::Output) -> String {
|
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||||
let mut body = "Semantic search results:\n".to_string();
|
match &output {
|
||||||
|
Ok(excerpts) => {
|
||||||
|
if excerpts.len() == 0 {
|
||||||
|
return "No results found".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
for excerpt in excerpts {
|
let mut body = "Semantic search results:\n".to_string();
|
||||||
body.push_str("Excerpt from ");
|
|
||||||
body.push_str(excerpt.path.as_ref());
|
for excerpt in excerpts {
|
||||||
body.push_str(", score ");
|
body.push_str("Excerpt from ");
|
||||||
body.push_str(&excerpt.score.to_string());
|
body.push_str(excerpt.path.as_ref());
|
||||||
body.push_str(":\n");
|
body.push_str(", score ");
|
||||||
body.push_str("~~~\n");
|
body.push_str(&excerpt.score.to_string());
|
||||||
body.push_str(excerpt.text.as_ref());
|
body.push_str(":\n");
|
||||||
body.push_str("~~~\n");
|
body.push_str("~~~\n");
|
||||||
|
body.push_str(excerpt.text.as_ref());
|
||||||
|
body.push_str("~~~\n");
|
||||||
|
}
|
||||||
|
body
|
||||||
|
}
|
||||||
|
Err(err) => format!("Error: {}", err),
|
||||||
}
|
}
|
||||||
body
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use gpui::{AnyElement, AppContext, Task, WindowContext};
|
use gpui::{Task, WindowContext};
|
||||||
use std::{any::Any, collections::HashMap};
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::tool::{
|
use crate::tool::{
|
||||||
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
LanguageModelTool, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ToolRegistry {
|
pub struct ToolRegistry {
|
||||||
tools: HashMap<String, Box<dyn Fn(&ToolFunctionCall, &AppContext) -> Task<ToolFunctionCall>>>,
|
tools: HashMap<
|
||||||
|
String,
|
||||||
|
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||||
|
>,
|
||||||
definitions: Vec<ToolFunctionDefinition>,
|
definitions: Vec<ToolFunctionDefinition>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,77 +27,45 @@ impl ToolRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||||
fn render<T: 'static + LanguageModelTool>(
|
|
||||||
tool_call_id: &str,
|
|
||||||
input: &Box<dyn Any>,
|
|
||||||
output: &Box<dyn Any>,
|
|
||||||
cx: &mut WindowContext,
|
|
||||||
) -> AnyElement {
|
|
||||||
T::render(
|
|
||||||
tool_call_id,
|
|
||||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
|
||||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
|
||||||
cx,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format<T: 'static + LanguageModelTool>(
|
|
||||||
input: &Box<dyn Any>,
|
|
||||||
output: &Box<dyn Any>,
|
|
||||||
) -> String {
|
|
||||||
T::format(
|
|
||||||
input.as_ref().downcast_ref::<T::Input>().unwrap(),
|
|
||||||
output.as_ref().downcast_ref::<T::Output>().unwrap(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
self.definitions.push(tool.definition());
|
self.definitions.push(tool.definition());
|
||||||
let name = tool.name();
|
let name = tool.name();
|
||||||
let previous = self.tools.insert(
|
let previous = self.tools.insert(
|
||||||
name.clone(),
|
name.clone(),
|
||||||
Box::new(move |tool_call: &ToolFunctionCall, cx: &AppContext| {
|
// registry.call(tool_call, cx)
|
||||||
let name = tool_call.name.clone();
|
Box::new(
|
||||||
let arguments = tool_call.arguments.clone();
|
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||||
let id = tool_call.id.clone();
|
let name = tool_call.name.clone();
|
||||||
|
let arguments = tool_call.arguments.clone();
|
||||||
|
let id = tool_call.id.clone();
|
||||||
|
|
||||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||||
return Task::ready(ToolFunctionCall {
|
return Task::ready(Ok(ToolFunctionCall {
|
||||||
id,
|
|
||||||
name: name.clone(),
|
|
||||||
arguments,
|
|
||||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = tool.execute(&input, cx);
|
|
||||||
|
|
||||||
cx.spawn(move |_cx| async move {
|
|
||||||
match result.await {
|
|
||||||
Ok(result) => {
|
|
||||||
let result: T::Output = result;
|
|
||||||
ToolFunctionCall {
|
|
||||||
id,
|
|
||||||
name: name.clone(),
|
|
||||||
arguments,
|
|
||||||
result: Some(ToolFunctionCallResult::Finished {
|
|
||||||
input: Box::new(input),
|
|
||||||
output: Box::new(result),
|
|
||||||
render_fn: render::<T>,
|
|
||||||
format_fn: format::<T>,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_error) => ToolFunctionCall {
|
|
||||||
id,
|
id,
|
||||||
name: name.clone(),
|
name: name.clone(),
|
||||||
arguments,
|
arguments,
|
||||||
result: Some(ToolFunctionCallResult::ExecutionFailed {
|
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||||
input: Box::new(input),
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tool.execute(&input, cx);
|
||||||
|
|
||||||
|
cx.spawn(move |mut cx| async move {
|
||||||
|
let result: Result<T::Output> = result.await;
|
||||||
|
let for_model = T::format(&input, &result);
|
||||||
|
let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
|
||||||
|
|
||||||
|
Ok(ToolFunctionCall {
|
||||||
|
id,
|
||||||
|
name: name.clone(),
|
||||||
|
arguments,
|
||||||
|
result: Some(ToolFunctionCallResult::Finished {
|
||||||
|
view: view.into(),
|
||||||
|
for_model,
|
||||||
}),
|
}),
|
||||||
},
|
})
|
||||||
}
|
})
|
||||||
})
|
},
|
||||||
}),
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
if previous.is_some() {
|
if previous.is_some() {
|
||||||
|
@ -104,7 +75,12 @@ impl ToolRegistry {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn call(&self, tool_call: &ToolFunctionCall, cx: &AppContext) -> Task<ToolFunctionCall> {
|
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
||||||
|
pub fn call(
|
||||||
|
&self,
|
||||||
|
tool_call: &ToolFunctionCall,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> Task<Result<ToolFunctionCall>> {
|
||||||
let name = tool_call.name.clone();
|
let name = tool_call.name.clone();
|
||||||
let arguments = tool_call.arguments.clone();
|
let arguments = tool_call.arguments.clone();
|
||||||
let id = tool_call.id.clone();
|
let id = tool_call.id.clone();
|
||||||
|
@ -113,12 +89,12 @@ impl ToolRegistry {
|
||||||
Some(tool) => tool,
|
Some(tool) => tool,
|
||||||
None => {
|
None => {
|
||||||
let name = name.clone();
|
let name = name.clone();
|
||||||
return Task::ready(ToolFunctionCall {
|
return Task::ready(Ok(ToolFunctionCall {
|
||||||
id,
|
id,
|
||||||
name: name.clone(),
|
name: name.clone(),
|
||||||
arguments,
|
arguments,
|
||||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||||
});
|
}));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -128,12 +104,10 @@ impl ToolRegistry {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use gpui::View;
|
||||||
|
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||||
use schemars::schema_for;
|
use schemars::schema_for;
|
||||||
|
|
||||||
use gpui::{div, AnyElement, Element, ParentElement, TestAppContext, WindowContext};
|
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
@ -155,9 +129,20 @@ mod test {
|
||||||
unit: String,
|
unit: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct WeatherView {
|
||||||
|
result: WeatherResult,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Render for WeatherView {
|
||||||
|
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
|
||||||
|
div().child(format!("temperature: {}", self.result.temperature))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl LanguageModelTool for WeatherTool {
|
impl LanguageModelTool for WeatherTool {
|
||||||
type Input = WeatherQuery;
|
type Input = WeatherQuery;
|
||||||
type Output = WeatherResult;
|
type Output = WeatherResult;
|
||||||
|
type View = WeatherView;
|
||||||
|
|
||||||
fn name(&self) -> String {
|
fn name(&self) -> String {
|
||||||
"get_current_weather".to_string()
|
"get_current_weather".to_string()
|
||||||
|
@ -167,7 +152,11 @@ mod test {
|
||||||
"Fetches the current weather for a given location.".to_string()
|
"Fetches the current weather for a given location.".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn execute(&self, input: &WeatherQuery, _cx: &AppContext) -> Task<Result<Self::Output>> {
|
fn execute(
|
||||||
|
&self,
|
||||||
|
input: &Self::Input,
|
||||||
|
_cx: &gpui::AppContext,
|
||||||
|
) -> Task<Result<Self::Output>> {
|
||||||
let _location = input.location.clone();
|
let _location = input.location.clone();
|
||||||
let _unit = input.unit.clone();
|
let _unit = input.unit.clone();
|
||||||
|
|
||||||
|
@ -176,25 +165,20 @@ mod test {
|
||||||
Task::ready(Ok(weather))
|
Task::ready(Ok(weather))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render(
|
fn new_view(
|
||||||
_tool_call_id: &str,
|
_tool_call_id: String,
|
||||||
_input: &Self::Input,
|
_input: Self::Input,
|
||||||
output: &Self::Output,
|
result: Result<Self::Output>,
|
||||||
_cx: &mut WindowContext,
|
cx: &mut WindowContext,
|
||||||
) -> AnyElement {
|
) -> View<Self::View> {
|
||||||
div()
|
cx.new_view(|_cx| {
|
||||||
.child(format!(
|
let result = result.unwrap();
|
||||||
"The current temperature in {} is {} {}",
|
WeatherView { result }
|
||||||
output.location, output.temperature, output.unit
|
})
|
||||||
))
|
|
||||||
.into_any()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format(_input: &Self::Input, output: &Self::Output) -> String {
|
fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
|
||||||
format!(
|
serde_json::to_string(&output.as_ref().unwrap()).unwrap()
|
||||||
"The current temperature in {} is {} {}",
|
|
||||||
output.location, output.temperature, output.unit
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,20 +198,20 @@ mod test {
|
||||||
|
|
||||||
registry.register(tool).unwrap();
|
registry.register(tool).unwrap();
|
||||||
|
|
||||||
let _result = cx
|
// let _result = cx
|
||||||
.update(|cx| {
|
// .update(|cx| {
|
||||||
registry.call(
|
// registry.call(
|
||||||
&ToolFunctionCall {
|
// &ToolFunctionCall {
|
||||||
name: "get_current_weather".to_string(),
|
// name: "get_current_weather".to_string(),
|
||||||
arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
|
// arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
|
||||||
.to_string(),
|
// .to_string(),
|
||||||
id: "test-123".to_string(),
|
// id: "test-123".to_string(),
|
||||||
result: None,
|
// result: None,
|
||||||
},
|
// },
|
||||||
cx,
|
// cx,
|
||||||
)
|
// )
|
||||||
})
|
// })
|
||||||
.await;
|
// .await;
|
||||||
|
|
||||||
// assert!(result.is_ok());
|
// assert!(result.is_ok());
|
||||||
// let result = result.unwrap();
|
// let result = result.unwrap();
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
|
use gpui::{AnyElement, AnyView, AppContext, IntoElement as _, Render, Task, View, WindowContext};
|
||||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{
|
use std::fmt::Display;
|
||||||
any::Any,
|
|
||||||
fmt::{Debug, Display},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Default, Deserialize)]
|
#[derive(Default, Deserialize)]
|
||||||
pub struct ToolFunctionCall {
|
pub struct ToolFunctionCall {
|
||||||
|
@ -19,71 +16,29 @@ pub struct ToolFunctionCall {
|
||||||
pub enum ToolFunctionCallResult {
|
pub enum ToolFunctionCallResult {
|
||||||
NoSuchTool,
|
NoSuchTool,
|
||||||
ParsingFailed,
|
ParsingFailed,
|
||||||
ExecutionFailed {
|
Finished { for_model: String, view: AnyView },
|
||||||
input: Box<dyn Any>,
|
|
||||||
},
|
|
||||||
Finished {
|
|
||||||
input: Box<dyn Any>,
|
|
||||||
output: Box<dyn Any>,
|
|
||||||
render_fn: fn(
|
|
||||||
// tool_call_id
|
|
||||||
&str,
|
|
||||||
// LanguageModelTool::Input
|
|
||||||
&Box<dyn Any>,
|
|
||||||
// LanguageModelTool::Output
|
|
||||||
&Box<dyn Any>,
|
|
||||||
&mut WindowContext,
|
|
||||||
) -> AnyElement,
|
|
||||||
format_fn: fn(
|
|
||||||
// LanguageModelTool::Input
|
|
||||||
&Box<dyn Any>,
|
|
||||||
// LanguageModelTool::Output
|
|
||||||
&Box<dyn Any>,
|
|
||||||
) -> String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolFunctionCallResult {
|
impl ToolFunctionCallResult {
|
||||||
pub fn render(
|
pub fn format(&self, name: &String) -> String {
|
||||||
&self,
|
|
||||||
tool_name: &str,
|
|
||||||
tool_call_id: &str,
|
|
||||||
cx: &mut WindowContext,
|
|
||||||
) -> AnyElement {
|
|
||||||
match self {
|
match self {
|
||||||
ToolFunctionCallResult::NoSuchTool => {
|
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
||||||
div().child(format!("no such tool {tool_name}")).into_any()
|
ToolFunctionCallResult::ParsingFailed => {
|
||||||
|
format!("Unable to parse arguments for {name}")
|
||||||
}
|
}
|
||||||
ToolFunctionCallResult::ParsingFailed => div()
|
ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
|
||||||
.child(format!("failed to parse input for tool {tool_name}"))
|
|
||||||
.into_any(),
|
|
||||||
ToolFunctionCallResult::ExecutionFailed { .. } => div()
|
|
||||||
.child(format!("failed to execute tool {tool_name}"))
|
|
||||||
.into_any(),
|
|
||||||
ToolFunctionCallResult::Finished {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
render_fn,
|
|
||||||
..
|
|
||||||
} => render_fn(tool_call_id, input, output, cx),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(&self, tool: &str) -> String {
|
pub fn into_any_element(&self, name: &String) -> AnyElement {
|
||||||
match self {
|
match self {
|
||||||
ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
|
ToolFunctionCallResult::NoSuchTool => {
|
||||||
|
format!("Language Model attempted to call {name}").into_any_element()
|
||||||
|
}
|
||||||
ToolFunctionCallResult::ParsingFailed => {
|
ToolFunctionCallResult::ParsingFailed => {
|
||||||
format!("failed to parse input for tool {tool}")
|
format!("Language Model called {name} with bad arguments").into_any_element()
|
||||||
}
|
}
|
||||||
ToolFunctionCallResult::ExecutionFailed { input: _input } => {
|
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
||||||
format!("failed to execute tool {tool}")
|
|
||||||
}
|
|
||||||
ToolFunctionCallResult::Finished {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
format_fn,
|
|
||||||
..
|
|
||||||
} => format_fn(input, output),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,19 +60,6 @@ impl Display for ToolFunctionDefinition {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for ToolFunctionDefinition {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
let schema = serde_json::to_string(&self.parameters).ok();
|
|
||||||
let schema = schema.unwrap_or("None".to_string());
|
|
||||||
|
|
||||||
f.debug_struct("ToolFunctionDefinition")
|
|
||||||
.field("name", &self.name)
|
|
||||||
.field("description", &self.description)
|
|
||||||
.field("parameters", &schema)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait LanguageModelTool {
|
pub trait LanguageModelTool {
|
||||||
/// The input type that will be passed in to `execute` when the tool is called
|
/// The input type that will be passed in to `execute` when the tool is called
|
||||||
/// by the language model.
|
/// by the language model.
|
||||||
|
@ -126,6 +68,8 @@ pub trait LanguageModelTool {
|
||||||
/// The output returned by executing the tool.
|
/// The output returned by executing the tool.
|
||||||
type Output: 'static;
|
type Output: 'static;
|
||||||
|
|
||||||
|
type View: Render;
|
||||||
|
|
||||||
/// The name of the tool is exposed to the language model to allow
|
/// The name of the tool is exposed to the language model to allow
|
||||||
/// the model to pick which tools to use. As this name is used to
|
/// the model to pick which tools to use. As this name is used to
|
||||||
/// identify the tool within a tool registry, it should be unique.
|
/// identify the tool within a tool registry, it should be unique.
|
||||||
|
@ -149,12 +93,12 @@ pub trait LanguageModelTool {
|
||||||
/// Execute the tool
|
/// Execute the tool
|
||||||
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
|
fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
|
||||||
|
|
||||||
fn render(
|
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
|
||||||
tool_call_id: &str,
|
|
||||||
input: &Self::Input,
|
|
||||||
output: &Self::Output,
|
|
||||||
cx: &mut WindowContext,
|
|
||||||
) -> AnyElement;
|
|
||||||
|
|
||||||
fn format(input: &Self::Input, output: &Self::Output) -> String;
|
fn new_view(
|
||||||
|
tool_call_id: String,
|
||||||
|
input: Self::Input,
|
||||||
|
output: Result<Self::Output>,
|
||||||
|
cx: &mut WindowContext,
|
||||||
|
) -> View<Self::View>;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue