Implement serialization of assistant conversations, including tool calls and attachments (#11577)

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-08 14:52:15 -07:00 committed by GitHub
parent 24ffa0fcf3
commit a7aa2578e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 585 additions and 253 deletions

View file

@ -1,41 +1,60 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::Deserialize;
use serde_json::Value;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{value::RawValue, Value};
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
sync::atomic::{AtomicBool, Ordering::SeqCst},
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use crate::ProjectContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default, Deserialize)]
#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
#[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
pub result: Option<SavedToolFunctionCallResult>,
}
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
}
#[derive(Serialize, Deserialize)]
pub enum SavedToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
serialized_output: Result<Box<RawValue>, String>,
},
}
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition {
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: for<'de> Deserialize<'de> + JsonSchema;
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type Output: 'static;
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
@ -80,7 +99,8 @@ pub trait LanguageModelTool {
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
/// A view of the output of running the tool, for displaying to the user.
fn output_view(
fn view(
&self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@ -102,7 +122,8 @@ pub trait ToolOutput: Sized {
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
definition: ToolFunctionDefinition,
}
@ -162,23 +183,125 @@ impl ToolRegistry {
}
}
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: call.result.as_ref().map(|result| match result {
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
ToolFunctionCallResult::Finished {
serialized_output, ..
} => SavedToolFunctionCallResult::Finished {
serialized_output: match serialized_output {
Ok(value) => Ok(value.clone()),
Err(e) => Err(e.to_string()),
},
},
}),
}
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> ToolFunctionCall {
if let Some(tool) = &self.registered_tools.get(&call.name) {
(tool.deserialize)(call, cx)
} else {
ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
result: Some(ToolFunctionCallResult::NoSuchTool),
}
}
}
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
let name = tool.name();
let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
call: Box::new(
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
deserialize: Box::new({
let tool = tool.clone();
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let id = tool_call.id.clone();
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::ParsingFailed),
};
};
let result = match &tool_call.result {
Some(result) => match result {
SavedToolFunctionCallResult::NoSuchTool => {
Some(ToolFunctionCallResult::NoSuchTool)
}
SavedToolFunctionCallResult::ParsingFailed => {
Some(ToolFunctionCallResult::ParsingFailed)
}
SavedToolFunctionCallResult::Finished { serialized_output } => {
let output = match serialized_output {
Ok(value) => {
match serde_json::from_str::<T::Output>(value.get()) {
Ok(value) => Ok(value),
Err(_) => {
return ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(
ToolFunctionCallResult::ParsingFailed,
),
};
}
}
}
Err(e) => Err(anyhow!("{e}")),
};
let view = tool.view(input, output, cx).into();
Some(ToolFunctionCallResult::Finished {
serialized_output: serialized_output.clone(),
generate_fn: generate::<T>,
view,
})
}
},
None => None,
};
ToolFunctionCall {
id: tool_call.id.clone(),
name: name.clone(),
arguments: tool_call.arguments.clone(),
result,
}
}
}),
execute: Box::new({
let tool = tool.clone();
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
@ -188,23 +311,33 @@ impl ToolRegistry {
};
let result = tool.execute(&input, cx);
let tool = tool.clone();
cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await;
let view = cx.update(|cx| T::output_view(input, result, cx))?;
let result = result.await;
let serialized_output = result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| tool.view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
serialized_output,
view: view.into(),
generate_fn: generate::<T>,
}),
})
})
},
),
}
}),
render_running: render_running::<T>,
};
@ -259,7 +392,7 @@ impl ToolRegistry {
}
};
(tool.call)(tool_call, cx)
(tool.execute)(tool_call, cx)
}
}
@ -275,9 +408,9 @@ impl ToolFunctionCallResult {
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
ToolFunctionCallResult::Finished { generate_fn, view } => {
(generate_fn)(view.clone(), project, cx)
}
ToolFunctionCallResult::Finished {
generate_fn, view, ..
} => (generate_fn)(view.clone(), project, cx),
}
}
@ -373,7 +506,8 @@ mod test {
Task::ready(Ok(weather))
}
fn output_view(
fn view(
&self,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,