Implement serialization of assistant conversations, including tool calls and attachments (#11577)
Release Notes: - N/A --------- Co-authored-by: Kyle <kylek@zed.dev> Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
24ffa0fcf3
commit
a7aa2578e1
12 changed files with 585 additions and 253 deletions
|
@ -2,9 +2,12 @@ mod attachment_registry;
|
|||
mod project_context;
|
||||
mod tool_registry;
|
||||
|
||||
pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
|
||||
pub use attachment_registry::{
|
||||
AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
|
||||
};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition,
|
||||
tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
|
||||
SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
|
||||
ToolOutput, ToolRegistry,
|
||||
};
|
||||
|
|
|
@ -3,6 +3,8 @@ use anyhow::{anyhow, Result};
|
|||
use collections::HashMap;
|
||||
use futures::future::join_all;
|
||||
use gpui::{AnyView, Render, Task, View, WindowContext};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
sync::{
|
||||
|
@ -17,24 +19,34 @@ pub struct AttachmentRegistry {
|
|||
}
|
||||
|
||||
pub trait LanguageModelAttachment {
|
||||
type Output: 'static;
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
fn name(&self) -> Arc<str>;
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
/// A collected attachment from running an attachment tool
|
||||
pub struct UserAttachment {
|
||||
pub view: AnyView,
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedUserAttachment {
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
}
|
||||
|
||||
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
||||
struct RegisteredAttachment {
|
||||
name: Arc<str>,
|
||||
enabled: AtomicBool,
|
||||
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
||||
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
|
||||
}
|
||||
|
||||
impl AttachmentRegistry {
|
||||
|
@ -45,24 +57,65 @@ impl AttachmentRegistry {
|
|||
}
|
||||
|
||||
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
|
||||
let call = Box::new(move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = Arc::new(attachment);
|
||||
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let view = cx.update(|cx| A::view(result, cx))?;
|
||||
let call = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = attachment.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let serialized_output =
|
||||
result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
|
||||
let view = cx.update(|cx| attachment.view(result, cx))?;
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: attachment.name(),
|
||||
view: view.into(),
|
||||
generate_fn: generate::<A>,
|
||||
serialized_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let deserialize = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
|
||||
let serialized_output = saved_attachment.serialized_output.clone();
|
||||
let output = match &serialized_output {
|
||||
Ok(serialized_output) => {
|
||||
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
|
||||
}
|
||||
Err(error) => Err(anyhow!("{error}")),
|
||||
};
|
||||
let view = attachment.view(output, cx).into();
|
||||
|
||||
Ok(UserAttachment {
|
||||
view: view.into(),
|
||||
name: saved_attachment.name.clone(),
|
||||
view,
|
||||
serialized_output,
|
||||
generate_fn: generate::<A>,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
self.registered_attachments.insert(
|
||||
TypeId::of::<A>(),
|
||||
RegisteredAttachment {
|
||||
name: attachment.name(),
|
||||
call,
|
||||
deserialize,
|
||||
enabled: AtomicBool::new(true),
|
||||
},
|
||||
);
|
||||
|
@ -134,6 +187,35 @@ impl AttachmentRegistry {
|
|||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_user_attachment(
|
||||
&self,
|
||||
user_attachment: &UserAttachment,
|
||||
) -> SavedUserAttachment {
|
||||
SavedUserAttachment {
|
||||
name: user_attachment.name.clone(),
|
||||
serialized_output: user_attachment.serialized_output.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_user_attachment(
|
||||
&self,
|
||||
saved_user_attachment: SavedUserAttachment,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<UserAttachment> {
|
||||
if let Some(registered_attachment) = self
|
||||
.registered_attachments
|
||||
.values()
|
||||
.find(|attachment| attachment.name == saved_user_attachment.name)
|
||||
{
|
||||
(registered_attachment.deserialize)(&saved_user_attachment, cx)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"no attachment tool for name {}",
|
||||
saved_user_attachment.name
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAttachment {
|
||||
|
|
|
@ -1,41 +1,60 @@
|
|||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{
|
||||
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
|
||||
};
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::{value::RawValue, Value};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
fmt::Display,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::ProjectContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
registered_tools: HashMap<String, RegisteredTool>,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
#[serde(skip)]
|
||||
pub result: Option<ToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct SavedToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub result: Option<SavedToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
pub enum ToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
view: AnyView,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum SavedToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
|
@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition {
|
|||
pub trait LanguageModelTool {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type Output: 'static;
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
|
@ -80,7 +99,8 @@ pub trait LanguageModelTool {
|
|||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
/// A view of the output of running the tool, for displaying to the user.
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
|
@ -102,7 +122,8 @@ pub trait ToolOutput: Sized {
|
|||
struct RegisteredTool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
|
||||
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
|
@ -162,23 +183,125 @@ impl ToolRegistry {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
|
||||
SavedToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: call.result.as_ref().map(|result| match result {
|
||||
ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
|
||||
ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
|
||||
ToolFunctionCallResult::Finished {
|
||||
serialized_output, ..
|
||||
} => SavedToolFunctionCallResult::Finished {
|
||||
serialized_output: match serialized_output {
|
||||
Ok(value) => Ok(value.clone()),
|
||||
Err(e) => Err(e.to_string()),
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_tool_call(
|
||||
&self,
|
||||
call: &SavedToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> ToolFunctionCall {
|
||||
if let Some(tool) = &self.registered_tools.get(&call.name) {
|
||||
(tool.deserialize)(call, cx)
|
||||
} else {
|
||||
ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Result<()> {
|
||||
let name = tool.name();
|
||||
let tool = Arc::new(tool);
|
||||
let registered_tool = RegisteredTool {
|
||||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
enabled: AtomicBool::new(true),
|
||||
call: Box::new(
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
deserialize: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
};
|
||||
};
|
||||
|
||||
let result = match &tool_call.result {
|
||||
Some(result) => match result {
|
||||
SavedToolFunctionCallResult::NoSuchTool => {
|
||||
Some(ToolFunctionCallResult::NoSuchTool)
|
||||
}
|
||||
SavedToolFunctionCallResult::ParsingFailed => {
|
||||
Some(ToolFunctionCallResult::ParsingFailed)
|
||||
}
|
||||
SavedToolFunctionCallResult::Finished { serialized_output } => {
|
||||
let output = match serialized_output {
|
||||
Ok(value) => {
|
||||
match serde_json::from_str::<T::Output>(value.get()) {
|
||||
Ok(value) => Ok(value),
|
||||
Err(_) => {
|
||||
return ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(
|
||||
ToolFunctionCallResult::ParsingFailed,
|
||||
),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(anyhow!("{e}")),
|
||||
};
|
||||
|
||||
let view = tool.view(input, output, cx).into();
|
||||
Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output: serialized_output.clone(),
|
||||
generate_fn: generate::<T>,
|
||||
view,
|
||||
})
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
ToolFunctionCall {
|
||||
id: tool_call.id.clone(),
|
||||
name: name.clone(),
|
||||
arguments: tool_call.arguments.clone(),
|
||||
result,
|
||||
}
|
||||
}
|
||||
}),
|
||||
execute: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
|
@ -188,23 +311,33 @@ impl ToolRegistry {
|
|||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
let tool = tool.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<T::Output> = result.await;
|
||||
let view = cx.update(|cx| T::output_view(input, result, cx))?;
|
||||
let result = result.await;
|
||||
let serialized_output = result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
let view = cx.update(|cx| tool.view(input, result, cx))?;
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
serialized_output,
|
||||
view: view.into(),
|
||||
generate_fn: generate::<T>,
|
||||
}),
|
||||
})
|
||||
})
|
||||
},
|
||||
),
|
||||
}
|
||||
}),
|
||||
render_running: render_running::<T>,
|
||||
};
|
||||
|
||||
|
@ -259,7 +392,7 @@ impl ToolRegistry {
|
|||
}
|
||||
};
|
||||
|
||||
(tool.call)(tool_call, cx)
|
||||
(tool.execute)(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -275,9 +408,9 @@ impl ToolFunctionCallResult {
|
|||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Unable to parse arguments for {name}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished { generate_fn, view } => {
|
||||
(generate_fn)(view.clone(), project, cx)
|
||||
}
|
||||
ToolFunctionCallResult::Finished {
|
||||
generate_fn, view, ..
|
||||
} => (generate_fn)(view.clone(), project, cx),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -373,7 +506,8 @@ mod test {
|
|||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
fn view(
|
||||
&self,
|
||||
_input: Self::Input,
|
||||
result: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue