Implement serialization of assistant conversations, including tool calls and attachments (#11577)

Release Notes:

- N/A

---------

Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-05-08 14:52:15 -07:00 committed by GitHub
parent 24ffa0fcf3
commit a7aa2578e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 585 additions and 253 deletions

View file

@ -6,19 +6,14 @@ mod saved_conversation_picker;
mod tools;
pub mod ui;
use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
use crate::saved_conversation_picker::SavedConversationPicker;
use crate::{
attachments::ActiveEditorAttachmentTool,
tools::{CreateBufferTool, ProjectIndexTool},
ui::UserOrAssistant,
};
use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
UserAttachment,
};
use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
@ -33,11 +28,13 @@ use gpui::{
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use rich_text::RichText;
use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
use saved_conversation_picker::SavedConversationPicker;
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
use tools::AnnotationTool;
use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
use util::paths::CONVERSATIONS_DIR;
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
@ -506,13 +503,11 @@ impl AssistantChat {
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
messages,
..
})) = this.messages.last_mut()
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if messages.is_empty() {
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
})
@ -563,7 +558,7 @@ impl AssistantChat {
let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
if let Some(ChatMessage::Assistant(AssistantMessage {
error: message_error,
messages,
..
@ -592,7 +587,7 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
this.update(cx, |this, cx| {
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if let Some(current_message) = messages.last_mut() {
@ -608,19 +603,19 @@ impl AssistantChat {
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
// If the last message is a grouped assistant message, add to the grouped message
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
self.messages.last_mut()
{
messages.push(AssistantMessage {
messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}
let message = ChatMessage::Assistant(GroupedAssistantMessage {
let message = ChatMessage::Assistant(AssistantMessage {
id: self.next_message_id.post_inc(),
messages: vec![AssistantMessage {
messages: vec![AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
}],
@ -669,40 +664,30 @@ impl AssistantChat {
*entry = !*entry;
}
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = self
.messages
.drain(..)
.map(|message| {
let text = match &message {
ChatMessage::User(message) => message.body.read(cx).text(cx),
ChatMessage::Assistant(message) => message
.messages
.iter()
.map(|message| message.body.text.to_string())
.collect::<Vec<_>>()
.join("\n\n"),
};
SavedMessage {
id: message.id(),
role: match message {
ChatMessage::User(_) => SavedMessageRole::User,
ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
},
text,
}
})
.collect::<Vec<_>>();
// Reset the chat for the new conversation.
fn reset(&mut self) {
self.messages.clear();
self.list_state.reset(0);
self.editing_message.take();
self.collapsed_messages.clear();
}
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
let messages = std::mem::take(&mut self.messages)
.into_iter()
.map(|message| self.serialize_message(message, cx))
.collect::<Vec<_>>();
self.reset();
let title = messages
.first()
.map(|message| message.text.clone())
.map(|message| match message {
SavedChatMessage::User { body, .. } => body.clone(),
SavedChatMessage::Assistant { messages, .. } => messages
.first()
.map(|message| message.body.to_string())
.unwrap_or_default(),
})
.unwrap_or_else(|| "A conversation with the assistant.".to_string());
let saved_conversation = SavedConversation {
@ -836,7 +821,7 @@ impl AssistantChat {
}
})
.into_any(),
ChatMessage::Assistant(GroupedAssistantMessage {
ChatMessage::Assistant(AssistantMessage {
id,
messages,
error,
@ -917,7 +902,7 @@ impl AssistantChat {
content: body.read(cx).text(cx),
});
}
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
for message in messages {
let body = message.body.clone();
@ -971,6 +956,43 @@ impl AssistantChat {
Ok(completion_messages)
})
}
fn serialize_message(
&self,
message: ChatMessage,
cx: &mut ViewContext<AssistantChat>,
) -> SavedChatMessage {
match message {
ChatMessage::User(message) => SavedChatMessage::User {
id: message.id,
body: message.body.read(cx).text(cx),
attachments: message
.attachments
.iter()
.map(|attachment| {
self.attachment_registry
.serialize_user_attachment(attachment)
})
.collect(),
},
ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
id: message.id,
error: message.error,
messages: message
.messages
.iter()
.map(|message| SavedAssistantMessagePart {
body: message.body.text.clone(),
tool_calls: message
.tool_calls
.iter()
.map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
.collect(),
})
.collect(),
},
}
}
}
impl Render for AssistantChat {
@ -1053,17 +1075,10 @@ impl MessageId {
enum ChatMessage {
User(UserMessage),
Assistant(GroupedAssistantMessage),
Assistant(AssistantMessage),
}
impl ChatMessage {
pub fn id(&self) -> MessageId {
match self {
ChatMessage::User(message) => message.id,
ChatMessage::Assistant(message) => message.id,
}
}
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
match self {
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
@ -1073,18 +1088,18 @@ impl ChatMessage {
}
struct UserMessage {
id: MessageId,
body: View<Editor>,
attachments: Vec<UserAttachment>,
pub id: MessageId,
pub body: View<Editor>,
pub attachments: Vec<UserAttachment>,
}
struct AssistantMessagePart {
pub body: RichText,
pub tool_calls: Vec<ToolFunctionCall>,
}
struct AssistantMessage {
body: RichText,
tool_calls: Vec<ToolFunctionCall>,
}
struct GroupedAssistantMessage {
id: MessageId,
messages: Vec<AssistantMessage>,
error: Option<SharedString>,
pub id: MessageId,
pub messages: Vec<AssistantMessagePart>,
pub error: Option<SharedString>,
}