Streaming tools (#11629)
Stream characters in for tool calls to allow rendering partial input. https://github.com/zed-industries/zed/assets/836375/0f023a4b-9c46-4449-ae69-8b6bcab41673 Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
27ed0f4273
commit
50c45c7897
13 changed files with 786 additions and 653 deletions
|
@ -16,7 +16,9 @@ anyhow.workspace = true
|
|||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
repair_json.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
@ -3,11 +3,11 @@ mod project_context;
|
|||
mod tool_registry;
|
||||
|
||||
pub use attachment_registry::{
|
||||
AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
|
||||
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
|
||||
UserAttachment,
|
||||
};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
|
||||
SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
ToolOutput, ToolRegistry,
|
||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
|
||||
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{ProjectContext, ToolOutput};
|
||||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use futures::future::join_all;
|
||||
|
@ -18,9 +18,13 @@ pub struct AttachmentRegistry {
|
|||
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
|
||||
}
|
||||
|
||||
pub trait AttachmentOutput {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
}
|
||||
|
||||
pub trait LanguageModelAttachment {
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
type View: Render + ToolOutput;
|
||||
type View: Render + AttachmentOutput;
|
||||
|
||||
fn name(&self) -> Arc<str>;
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{
|
||||
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
|
||||
};
|
||||
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
|
||||
use repair_json::repair;
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
|
@ -15,6 +14,7 @@ use std::{
|
|||
Arc,
|
||||
},
|
||||
};
|
||||
use ui::ViewContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
registered_tools: HashMap<String, RegisteredTool>,
|
||||
|
@ -25,7 +25,25 @@ pub struct ToolFunctionCall {
|
|||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub result: Option<ToolFunctionCallResult>,
|
||||
state: ToolFunctionCallState,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub enum ToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
KnownTool(Box<dyn ToolView>),
|
||||
ExecutedTool(Box<dyn ToolView>),
|
||||
}
|
||||
|
||||
pub trait ToolView {
|
||||
fn view(&self) -> AnyView;
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
fn set_input(&self, input: &str, cx: &mut WindowContext);
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
||||
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
|
@ -33,29 +51,19 @@ pub struct SavedToolFunctionCall {
|
|||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub result: Option<SavedToolFunctionCallResult>,
|
||||
pub state: SavedToolFunctionCallState,
|
||||
}
|
||||
|
||||
pub enum ToolFunctionCallResult {
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub enum SavedToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
view: AnyView,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
|
||||
},
|
||||
KnownTool,
|
||||
ExecutedTool(Box<RawValue>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum SavedToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
|
@ -63,14 +71,7 @@ pub struct ToolFunctionDefinition {
|
|||
}
|
||||
|
||||
pub trait LanguageModelTool {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
|
||||
type View: Render + ToolOutput;
|
||||
type View: ToolOutput;
|
||||
|
||||
/// Returns the name of the tool.
|
||||
///
|
||||
|
@ -86,7 +87,7 @@ pub trait LanguageModelTool {
|
|||
|
||||
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
||||
fn definition(&self) -> ToolFunctionDefinition {
|
||||
let root_schema = schema_for!(Self::Input);
|
||||
let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
|
||||
|
||||
ToolFunctionDefinition {
|
||||
name: self.name(),
|
||||
|
@ -95,36 +96,46 @@ pub trait LanguageModelTool {
|
|||
}
|
||||
}
|
||||
|
||||
/// Executes the tool with the given input.
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
/// A view of the output of running the tool, for displaying to the user.
|
||||
fn view(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View>;
|
||||
|
||||
fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
|
||||
tool_running_placeholder()
|
||||
}
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
pub fn tool_running_placeholder() -> AnyElement {
|
||||
ui::Label::new("Researching...").into_any_element()
|
||||
}
|
||||
|
||||
pub trait ToolOutput: Sized {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
pub fn unknown_tool_placeholder() -> AnyElement {
|
||||
ui::Label::new("Unknown tool").into_any_element()
|
||||
}
|
||||
|
||||
pub fn no_such_tool_placeholder() -> AnyElement {
|
||||
ui::Label::new("No such tool").into_any_element()
|
||||
}
|
||||
|
||||
pub trait ToolOutput: Render {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type SerializedState: DeserializeOwned + Serialize;
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
|
||||
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
|
||||
|
||||
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
struct RegisteredTool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
|
||||
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
|
||||
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
|
||||
|
@ -161,63 +172,132 @@ impl ToolRegistry {
|
|||
.collect()
|
||||
}
|
||||
|
||||
pub fn render_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
match &tool_call.result {
|
||||
Some(result) => div()
|
||||
.p_2()
|
||||
.child(result.into_any_element(&tool_call.name))
|
||||
.into_any_element(),
|
||||
None => {
|
||||
let tool = self.registered_tools.get(&tool_call.name);
|
||||
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
|
||||
let tool = self.registered_tools.get(name)?;
|
||||
Some((tool.build_view)(cx))
|
||||
}
|
||||
|
||||
if let Some(tool) = tool {
|
||||
(tool.render_running)(&tool_call, cx)
|
||||
pub fn update_tool_call(
|
||||
&self,
|
||||
call: &mut ToolFunctionCall,
|
||||
name: Option<&str>,
|
||||
arguments: Option<&str>,
|
||||
cx: &mut WindowContext,
|
||||
) {
|
||||
if let Some(name) = name {
|
||||
call.name.push_str(name);
|
||||
}
|
||||
if let Some(arguments) = arguments {
|
||||
if call.arguments.is_empty() {
|
||||
if let Some(view) = self.view_for_tool(&call.name, cx) {
|
||||
call.state = ToolFunctionCallState::KnownTool(view);
|
||||
} else {
|
||||
tool_running_placeholder()
|
||||
call.state = ToolFunctionCallState::NoSuchTool;
|
||||
}
|
||||
}
|
||||
call.arguments.push_str(arguments);
|
||||
|
||||
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
||||
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
||||
view.set_input(&repaired_arguments, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
|
||||
SavedToolFunctionCall {
|
||||
pub fn execute_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
|
||||
Some(view.execute(cx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
_cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
|
||||
ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
view.view().into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn content_for_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
project_context: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::Initializing => String::new(),
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
format!("No such tool: {}", tool_call.name)
|
||||
}
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
view.generate(project_context, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(
|
||||
&self,
|
||||
call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<SavedToolFunctionCall> {
|
||||
Ok(SavedToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: call.result.as_ref().map(|result| match result {
|
||||
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
|
||||
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
|
||||
ToolFunctionCallResult::Finished {
|
||||
serialized_output, ..
|
||||
} => SavedToolFunctionCallResult::Finished {
|
||||
serialized_output: match serialized_output {
|
||||
Ok(value) => Ok(value.clone()),
|
||||
Err(e) => Err(e.to_string()),
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
state: match &call.state {
|
||||
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
|
||||
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
|
||||
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
|
||||
ToolFunctionCallState::ExecutedTool(view) => {
|
||||
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize_tool_call(
|
||||
&self,
|
||||
call: &SavedToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> ToolFunctionCall {
|
||||
if let Some(tool) = &self.registered_tools.get(&call.name) {
|
||||
(tool.deserialize)(call, cx)
|
||||
} else {
|
||||
ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
}
|
||||
}
|
||||
) -> Result<ToolFunctionCall> {
|
||||
let Some(tool) = self.registered_tools.get(&call.name) else {
|
||||
return Err(anyhow!("no such tool {}", call.name));
|
||||
};
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
state: match &call.state {
|
||||
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
|
||||
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
|
||||
SavedToolFunctionCallState::KnownTool => {
|
||||
log::error!("Deserialized tool that had not executed");
|
||||
let view = (tool.build_view)(cx);
|
||||
view.set_input(&call.arguments, cx);
|
||||
ToolFunctionCallState::KnownTool(view)
|
||||
}
|
||||
SavedToolFunctionCallState::ExecutedTool(output) => {
|
||||
let view = (tool.build_view)(cx);
|
||||
view.set_input(&call.arguments, cx);
|
||||
view.deserialize_output(output, cx)?;
|
||||
ToolFunctionCallState::ExecutedTool(view)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
|
@ -231,114 +311,7 @@ impl ToolRegistry {
|
|||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
enabled: AtomicBool::new(true),
|
||||
deserialize: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
};
|
||||
};
|
||||
|
||||
let result = match &tool_call.result {
|
||||
Some(result) => match result {
|
||||
SavedToolFunctionCallResult::NoSuchTool => {
|
||||
Some(ToolFunctionCallResult::NoSuchTool)
|
||||
}
|
||||
SavedToolFunctionCallResult::ParsingFailed => {
|
||||
Some(ToolFunctionCallResult::ParsingFailed)
|
||||
}
|
||||
SavedToolFunctionCallResult::Finished { serialized_output } => {
|
||||
let output = match serialized_output {
|
||||
Ok(value) => {
|
||||
match serde_json::from_str::<T::Output>(value.get()) {
|
||||
Ok(value) => Ok(value),
|
||||
Err(_) => {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(
|
||||
ToolFunctionCallResult::ParsingFailed,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(anyhow!("{e}")),
|
||||
};
|
||||
|
||||
let view = tool.view(input, output, cx).into();
|
||||
Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output: serialized_output.clone(),
|
||||
generate_fn: generate::<T>,
|
||||
view,
|
||||
})
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
ToolFunctionCall {
|
||||
id: tool_call.id.clone(),
|
||||
name: name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
result,
|
||||
}
|
||||
}
|
||||
}),
|
||||
execute: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
}));
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
let tool = tool.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result = result.await;
|
||||
let serialized_output = result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
let view = cx.update(|cx| tool.view(input, result, cx))?;
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output,
|
||||
view: view.into(),
|
||||
generate_fn: generate::<T>,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}),
|
||||
render_running: render_running::<T>,
|
||||
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
|
||||
};
|
||||
|
||||
let previous = self.registered_tools.insert(name.clone(), registered_tool);
|
||||
|
@ -347,83 +320,40 @@ impl ToolRegistry {
|
|||
}
|
||||
|
||||
return Ok(());
|
||||
|
||||
fn render_running<T: LanguageModelTool>(
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
// Attempt to parse the string arguments that are JSON as a JSON value
|
||||
let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
|
||||
|
||||
T::render_running(&maybe_arguments, cx).into_any_element()
|
||||
}
|
||||
|
||||
fn generate<T: LanguageModelTool>(
|
||||
view: AnyView,
|
||||
project: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
view.downcast::<T::View>()
|
||||
.unwrap()
|
||||
.update(cx, |view, cx| T::View::generate(view, project, cx))
|
||||
}
|
||||
}
|
||||
|
||||
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
||||
pub fn call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<ToolFunctionCall>> {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let tool = match self.registered_tools.get(&name) {
|
||||
Some(tool) => tool,
|
||||
None => {
|
||||
let name = name.clone();
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
(tool.execute)(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolFunctionCallResult {
|
||||
pub fn generate(
|
||||
&self,
|
||||
name: &String,
|
||||
project: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Unable to parse arguments for {name}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished {
|
||||
generate_fn, view, ..
|
||||
} => (generate_fn)(view.clone(), project, cx),
|
||||
impl<T: ToolOutput> ToolView for View<T> {
|
||||
fn view(&self) -> AnyView {
|
||||
self.clone().into()
|
||||
}
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||
self.update(cx, |view, cx| view.generate(project, cx))
|
||||
}
|
||||
|
||||
fn set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
||||
self.update(cx, |view, cx| {
|
||||
view.set_input(input, cx);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn into_any_element(&self, name: &String) -> AnyElement {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => {
|
||||
format!("Language Model attempted to call {name}").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Language Model called {name} with bad arguments").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
||||
}
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
|
||||
self.update(cx, |view, cx| view.execute(cx))
|
||||
}
|
||||
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
|
||||
let output = self.update(cx, |view, cx| view.serialize(cx));
|
||||
Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
|
||||
}
|
||||
|
||||
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
|
||||
let state = serde_json::from_str::<T::SerializedState>(output.get())?;
|
||||
self.update(cx, |view, cx| view.deserialize(state, cx))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -453,10 +383,6 @@ mod test {
|
|||
unit: String,
|
||||
}
|
||||
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
|
||||
struct WeatherResult {
|
||||
location: String,
|
||||
|
@ -465,24 +391,81 @@ mod test {
|
|||
}
|
||||
|
||||
struct WeatherView {
|
||||
result: WeatherResult,
|
||||
input: Option<WeatherQuery>,
|
||||
result: Option<WeatherResult>,
|
||||
|
||||
// Fake API call
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
impl WeatherView {
|
||||
fn new(current_weather: WeatherResult) -> Self {
|
||||
Self {
|
||||
input: None,
|
||||
result: None,
|
||||
current_weather,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for WeatherView {
|
||||
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
|
||||
div().child(format!("temperature: {}", self.result.temperature))
|
||||
match self.result {
|
||||
Some(ref result) => div()
|
||||
.child(format!("temperature: {}", result.temperature))
|
||||
.into_any_element(),
|
||||
None => div().child("Calculating weather...").into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolOutput for WeatherView {
|
||||
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
|
||||
type Input = WeatherQuery;
|
||||
|
||||
type SerializedState = WeatherResult;
|
||||
|
||||
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
|
||||
serde_json::to_string(&self.result).unwrap()
|
||||
}
|
||||
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
self.input = Some(input);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
let input = self.input.as_ref().unwrap();
|
||||
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
self.result = Some(weather);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
self.current_weather.clone()
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
self.current_weather = output;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelTool for WeatherTool {
|
||||
type Input = WeatherQuery;
|
||||
type Output = WeatherResult;
|
||||
type View = WeatherView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
|
@ -493,29 +476,8 @@ mod test {
|
|||
"Fetches the current weather for a given location.".to_string()
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
input: &Self::Input,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn view(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
result: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
cx.new_view(|_cx| {
|
||||
let result = result.unwrap();
|
||||
WeatherView { result }
|
||||
})
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -564,18 +526,14 @@ mod test {
|
|||
})
|
||||
);
|
||||
|
||||
let args = json!({
|
||||
"location": "San Francisco",
|
||||
"unit": "Celsius"
|
||||
let view = cx.update(|cx| tool.view(cx));
|
||||
|
||||
cx.update(|cx| {
|
||||
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
|
||||
});
|
||||
|
||||
let query: WeatherQuery = serde_json::from_value(args).unwrap();
|
||||
let finished = cx.update(|cx| view.execute(cx)).await;
|
||||
|
||||
let result = cx.update(|cx| tool.execute(&query, cx)).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
|
||||
assert_eq!(result, tool.current_weather);
|
||||
assert!(finished.is_ok());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue