Streaming tools (#11629)

Stream tool-call arguments in as they arrive, so that a tool's view can render partial input before the call is complete.


https://github.com/zed-industries/zed/assets/836375/0f023a4b-9c46-4449-ae69-8b6bcab41673

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
Kyle Kelley 2024-05-09 15:57:14 -07:00 committed by GitHub
parent 27ed0f4273
commit 50c45c7897
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 786 additions and 653 deletions

View file

@ -16,7 +16,9 @@ anyhow.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
log.workspace = true
project.workspace = true
repair_json.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true

View file

@ -3,11 +3,11 @@ mod project_context;
mod tool_registry;
pub use attachment_registry::{
AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
UserAttachment,
};
pub use project_context::ProjectContext;
pub use tool_registry::{
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
ToolOutput, ToolRegistry,
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
};

View file

@ -1,4 +1,4 @@
use crate::{ProjectContext, ToolOutput};
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
@ -18,9 +18,13 @@ pub struct AttachmentRegistry {
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
/// Textual output of an attachment view.
///
/// Required of `LanguageModelAttachment::View` alongside `Render`.
pub trait AttachmentOutput {
/// Produces this attachment's content as a `String`, with mutable access to
/// the `ProjectContext` and the window context.
// NOTE(review): presumably this text is what gets sent to the language model — confirm against callers.
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
pub trait LanguageModelAttachment {
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
type View: Render + AttachmentOutput;
fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;

View file

@ -1,11 +1,10 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
use repair_json::repair;
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use serde_json::value::RawValue;
use std::{
any::TypeId,
collections::HashMap,
@ -15,6 +14,7 @@ use std::{
Arc,
},
};
use ui::ViewContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
@ -25,7 +25,25 @@ pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<ToolFunctionCallResult>,
state: ToolFunctionCallState,
}
/// State of a `ToolFunctionCall` as its name and arguments stream in.
#[derive(Default)]
pub enum ToolFunctionCallState {
/// Default state: no view has been built for the call yet.
/// `render_tool_call` shows the "Unknown tool" placeholder for it and
/// `content_for_tool_call` yields an empty string.
#[default]
Initializing,
/// The call's name does not match any registered tool.
NoSuchTool,
/// A registered tool matched; holds the view built for it. This is the only
/// state from which `execute_tool_call` starts an execution.
KnownTool(Box<dyn ToolView>),
/// Holds the view of a tool whose output is available. This is the only
/// state whose output `serialize_tool_call` persists.
ExecutedTool(Box<dyn ToolView>),
}
/// Type-erased interface to a tool's view, letting the registry store views
/// of different tools as `Box<dyn ToolView>`.
///
/// Implemented for `View<T>` where `T: ToolOutput`.
pub trait ToolView {
/// Returns the underlying view as an `AnyView` so it can be rendered.
fn view(&self) -> AnyView;
/// Produces the tool's textual content; used by `content_for_tool_call`.
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
/// Supplies the tool's input as a JSON string. The `View<T>` implementation
/// ignores input that does not deserialize into the tool's `Input` type.
fn set_input(&self, input: &str, cx: &mut WindowContext);
/// Starts executing the tool and returns the task tracking its completion.
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
/// Serializes the view's state into raw JSON so the call can be saved.
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
/// Restores the view's state from raw JSON produced by `serialize_output`.
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
}
#[derive(Default, Serialize, Deserialize)]
@ -33,29 +51,19 @@ pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<SavedToolFunctionCallResult>,
pub state: SavedToolFunctionCallState,
}
pub enum ToolFunctionCallResult {
#[derive(Default, Serialize, Deserialize)]
pub enum SavedToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
KnownTool,
ExecutedTool(Box<RawValue>),
}
#[derive(Serialize, Deserialize)]
pub enum SavedToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
serialized_output: Result<Box<RawValue>, String>,
},
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
@ -63,14 +71,7 @@ pub struct ToolFunctionDefinition {
}
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
type View: ToolOutput;
/// Returns the name of the tool.
///
@ -86,7 +87,7 @@ pub trait LanguageModelTool {
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(Self::Input);
let root_schema = schema_for!(<Self::View as ToolOutput>::Input);
ToolFunctionDefinition {
name: self.name(),
@ -95,36 +96,46 @@ pub trait LanguageModelTool {
}
}
/// Executes the tool with the given input.
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
/// A view of the output of running the tool, for displaying to the user.
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
tool_running_placeholder()
}
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
/// Builds the placeholder element reading "Researching...".
pub fn tool_running_placeholder() -> AnyElement {
    let label = ui::Label::new("Researching...");
    label.into_any_element()
}
pub trait ToolOutput: Sized {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
/// Builds the placeholder element reading "Unknown tool", shown for a call
/// that is still in the `Initializing` state.
pub fn unknown_tool_placeholder() -> AnyElement {
    let label = ui::Label::new("Unknown tool");
    label.into_any_element()
}
/// Builds the placeholder element reading "No such tool", shown for a call
/// in the `NoSuchTool` state.
pub fn no_such_tool_placeholder() -> AnyElement {
    let label = ui::Label::new("No such tool");
    label.into_any_element()
}
/// A renderable view that owns a tool call's whole lifecycle: it receives the
/// input, executes the tool, and exposes the result as text and as
/// serializable state.
pub trait ToolOutput: Render {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
// This is the type produced by `serialize` and consumed by `deserialize`.
type SerializedState: DeserializeOwned + Serialize;
/// Produces the tool's textual content.
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
/// Stores the parsed input on the view. The registry calls this each time
/// more argument text arrives, so it may run repeatedly with input repaired
/// from partial JSON.
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
/// Runs the tool; the returned task reports success or failure.
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
/// Captures the view's state in its serializable form.
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
/// Restores the view from state previously returned by `serialize`.
fn deserialize(
&mut self,
output: Self::SerializedState,
cx: &mut ViewContext<Self>,
) -> Result<()>;
}
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn ToolView>>,
definition: ToolFunctionDefinition,
}
@ -161,63 +172,132 @@ impl ToolRegistry {
.collect()
}
pub fn render_tool_call(
&self,
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> AnyElement {
match &tool_call.result {
Some(result) => div()
.p_2()
.child(result.into_any_element(&tool_call.name))
.into_any_element(),
None => {
let tool = self.registered_tools.get(&tool_call.name);
/// Builds a fresh view for the tool registered under `name`.
///
/// Returns `None` when no tool with that name is registered.
pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
    self.registered_tools
        .get(name)
        .map(|registered| (registered.build_view)(cx))
}
if let Some(tool) = tool {
(tool.render_running)(&tool_call, cx)
pub fn update_tool_call(
&self,
call: &mut ToolFunctionCall,
name: Option<&str>,
arguments: Option<&str>,
cx: &mut WindowContext,
) {
if let Some(name) = name {
call.name.push_str(name);
}
if let Some(arguments) = arguments {
if call.arguments.is_empty() {
if let Some(view) = self.view_for_tool(&call.name, cx) {
call.state = ToolFunctionCallState::KnownTool(view);
} else {
tool_running_placeholder()
call.state = ToolFunctionCallState::NoSuchTool;
}
}
call.arguments.push_str(arguments);
if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
view.set_input(&repaired_arguments, cx)
}
}
}
}
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
SavedToolFunctionCall {
/// Starts executing `tool_call` if it is in the `KnownTool` state.
///
/// Returns the execution task, or `None` for every other state.
pub fn execute_tool_call(
    &self,
    tool_call: &ToolFunctionCall,
    cx: &mut WindowContext,
) -> Option<Task<Result<()>>> {
    match &tool_call.state {
        ToolFunctionCallState::KnownTool(view) => Some(view.execute(cx)),
        ToolFunctionCallState::Initializing
        | ToolFunctionCallState::NoSuchTool
        | ToolFunctionCallState::ExecutedTool(_) => None,
    }
}
/// Renders `tool_call` according to its state: a placeholder label while no
/// tool view exists, otherwise the tool's own view.
pub fn render_tool_call(
    &self,
    tool_call: &ToolFunctionCall,
    _cx: &mut WindowContext,
) -> AnyElement {
    let tool_view = match &tool_call.state {
        ToolFunctionCallState::Initializing => return unknown_tool_placeholder(),
        ToolFunctionCallState::NoSuchTool => return no_such_tool_placeholder(),
        ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
            view
        }
    };
    tool_view.view().into_any_element()
}
/// Returns the textual content of `tool_call`.
///
/// An `Initializing` call yields an empty string, a `NoSuchTool` call yields
/// a message naming the missing tool, and any call with a view defers to
/// that view's `generate`.
pub fn content_for_tool_call(
    &self,
    tool_call: &ToolFunctionCall,
    project_context: &mut ProjectContext,
    cx: &mut WindowContext,
) -> String {
    let tool_view = match &tool_call.state {
        ToolFunctionCallState::Initializing => return String::new(),
        ToolFunctionCallState::NoSuchTool => {
            return format!("No such tool: {}", tool_call.name);
        }
        ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
            view
        }
    };
    tool_view.generate(project_context, cx)
}
pub fn serialize_tool_call(
&self,
call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> Result<SavedToolFunctionCall> {
Ok(SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: call.result.as_ref().map(|result| match result {
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
ToolFunctionCallResult::Finished {
serialized_output, ..
} => SavedToolFunctionCallResult::Finished {
serialized_output: match serialized_output {
Ok(value) => Ok(value.clone()),
Err(e) => Err(e.to_string()),
},
},
}),
}
state: match &call.state {
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
ToolFunctionCallState::ExecutedTool(view) => {
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
}
},
})
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> ToolFunctionCall {
if let Some(tool) = &self.registered_tools.get(&call.name) {
(tool.deserialize)(call, cx)
} else {
ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: Some(ToolFunctionCallResult::NoSuchTool),
}
}
) -> Result<ToolFunctionCall> {
let Some(tool) = self.registered_tools.get(&call.name) else {
return Err(anyhow!("no such tool {}", call.name));
};
Ok(ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
state: match &call.state {
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view)
}
SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx);
view.set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view)
}
},
})
}
pub fn register<T: 'static + LanguageModelTool>(
@ -231,114 +311,7 @@ impl ToolRegistry {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
deserialize: Box::new({
let tool = tool.clone();
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
};
};
let result = match &tool_call.result {
Some(result) => match result {
SavedToolFunctionCallResult::NoSuchTool => {
Some(ToolFunctionCallResult::NoSuchTool)
}
SavedToolFunctionCallResult::ParsingFailed => {
Some(ToolFunctionCallResult::ParsingFailed)
}
SavedToolFunctionCallResult::Finished { serialized_output } => {
let output = match serialized_output {
Ok(value) => {
match serde_json::from_str::<T::Output>(value.get()) {
Ok(value) => Ok(value),
Err(_) => {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(
ToolFunctionCallResult::ParsingFailed,
),
};
}
}
}
Err(e) => Err(anyhow!("{e}")),
};
let view = tool.view(input, output, cx).into();
Some(ToolFunctionCallResult::Finished {
serialized_output: serialized_output.clone(),
generate_fn: generate::<T>,
view,
})
}
},
None => None,
};
ToolFunctionCall {
id: tool_call.id.clone(),
name: name.clone(),
arguments: tool_call.arguments.clone(),
result,
}
}
}),
execute: Box::new({
let tool = tool.clone();
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
}));
};
let result = tool.execute(&input, cx);
let tool = tool.clone();
cx.spawn(move |mut cx| async move {
let result = result.await;
let serialized_output = result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| tool.view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
serialized_output,
view: view.into(),
generate_fn: generate::<T>,
}),
})
})
}
}),
render_running: render_running::<T>,
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
};
let previous = self.registered_tools.insert(name.clone(), registered_tool);
@ -347,83 +320,40 @@ impl ToolRegistry {
}
return Ok(());
fn render_running<T: LanguageModelTool>(
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> AnyElement {
// Attempt to parse the string arguments that are JSON as a JSON value
let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
T::render_running(&maybe_arguments, cx).into_any_element()
}
fn generate<T: LanguageModelTool>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
}
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
pub fn call(
&self,
tool_call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> Task<Result<ToolFunctionCall>> {
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let tool = match self.registered_tools.get(&name) {
Some(tool) => tool,
None => {
let name = name.clone();
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::NoSuchTool),
}));
}
};
(tool.execute)(tool_call, cx)
}
}
impl ToolFunctionCallResult {
pub fn generate(
&self,
name: &String,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
match self {
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished {
generate_fn, view, ..
} => (generate_fn)(view.clone(), project, cx),
impl<T: ToolOutput> ToolView for View<T> {
/// Returns a type-erased clone of this view handle.
fn view(&self) -> AnyView {
    let handle = self.clone();
    handle.into()
}
/// Forwards to the wrapped `ToolOutput::generate` inside a view update.
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
    self.update(cx, |this, view_cx| this.generate(project, view_cx))
}
/// Parses `input` as JSON into the tool's `Input` type and hands it to the
/// wrapped view, then notifies so the view re-renders.
///
/// Input that fails to parse is ignored.
fn set_input(&self, input: &str, cx: &mut WindowContext) {
    let Ok(parsed) = serde_json::from_str::<T::Input>(input) else {
        return;
    };
    self.update(cx, |this, view_cx| {
        this.set_input(parsed, view_cx);
        view_cx.notify();
    });
}
fn into_any_element(&self, name: &String) -> AnyElement {
match self {
ToolFunctionCallResult::NoSuchTool => {
format!("Language Model attempted to call {name}").into_any_element()
}
ToolFunctionCallResult::ParsingFailed => {
format!("Language Model called {name} with bad arguments").into_any_element()
}
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
}
/// Forwards to the wrapped `ToolOutput::execute` inside a view update and
/// returns its task.
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
    self.update(cx, |this, view_cx| this.execute(view_cx))
}
/// Serializes the wrapped view's state to a raw JSON value.
///
/// Fails if the state cannot be encoded as JSON.
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
    let state = self.update(cx, |this, view_cx| this.serialize(view_cx));
    let json = serde_json::to_string(&state)?;
    let raw = RawValue::from_string(json)?;
    Ok(raw)
}
/// Decodes `output` into the tool's `SerializedState` and restores the
/// wrapped view from it.
///
/// Fails if the JSON does not match `SerializedState`, or if the view's own
/// `deserialize` rejects the state.
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
    let state: T::SerializedState = serde_json::from_str(output.get())?;
    self.update(cx, |this, view_cx| this.deserialize(state, view_cx))
}
}
@ -453,10 +383,6 @@ mod test {
unit: String,
}
struct WeatherTool {
current_weather: WeatherResult,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
@ -465,24 +391,81 @@ mod test {
}
struct WeatherView {
result: WeatherResult,
input: Option<WeatherQuery>,
result: Option<WeatherResult>,
// Fake API call
current_weather: WeatherResult,
}
/// Test tool whose `view` builds a `WeatherView` seeded with a clone of
/// `current_weather`.
#[derive(Clone, Serialize)]
struct WeatherTool {
// Canned weather standing in for a real API response.
current_weather: WeatherResult,
}
impl WeatherView {
fn new(current_weather: WeatherResult) -> Self {
Self {
input: None,
result: None,
current_weather,
}
}
}
impl Render for WeatherView {
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
div().child(format!("temperature: {}", self.result.temperature))
match self.result {
Some(ref result) => div()
.child(format!("temperature: {}", result.temperature))
.into_any_element(),
None => div().child("Calculating weather...").into_any_element(),
}
}
}
impl ToolOutput for WeatherView {
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
type Input = WeatherQuery;
type SerializedState = WeatherResult;
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
serde_json::to_string(&self.result).unwrap()
}
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
self.input = Some(input);
cx.notify();
}
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
let input = self.input.as_ref().unwrap();
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
self.result = Some(weather);
Task::ready(Ok(()))
}
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
self.current_weather.clone()
}
fn deserialize(
&mut self,
output: Self::SerializedState,
_cx: &mut ViewContext<Self>,
) -> Result<()> {
self.current_weather = output;
Ok(())
}
}
impl LanguageModelTool for WeatherTool {
type Input = WeatherQuery;
type Output = WeatherResult;
type View = WeatherView;
fn name(&self) -> String {
@ -493,29 +476,8 @@ mod test {
"Fetches the current weather for a given location.".to_string()
}
fn execute(
&self,
input: &Self::Input,
_cx: &mut WindowContext,
) -> Task<Result<Self::Output>> {
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
Task::ready(Ok(weather))
}
fn view(
&self,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View> {
cx.new_view(|_cx| {
let result = result.unwrap();
WeatherView { result }
})
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
}
}
@ -564,18 +526,14 @@ mod test {
})
);
let args = json!({
"location": "San Francisco",
"unit": "Celsius"
let view = cx.update(|cx| tool.view(cx));
cx.update(|cx| {
view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
});
let query: WeatherQuery = serde_json::from_value(args).unwrap();
let finished = cx.update(|cx| view.execute(cx)).await;
let result = cx.update(|cx| tool.execute(&query, cx)).await;
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result, tool.current_weather);
assert!(finished.is_ok());
}
}