Allow loading a previously-saved conversation

This commit is contained in:
Antonio Scandurra 2023-06-21 16:06:09 +02:00
parent 06701e78aa
commit a011ced698
5 changed files with 279 additions and 64 deletions

View file

@ -22,7 +22,7 @@ util = { path = "../util" }
workspace = { path = "../workspace" } workspace = { path = "../workspace" }
anyhow.workspace = true anyhow.workspace = true
chrono = "0.4" chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true futures.workspace = true
isahc.workspace = true isahc.workspace = true
regex.workspace = true regex.workspace = true

View file

@ -3,6 +3,8 @@ mod assistant_settings;
use anyhow::Result; use anyhow::Result;
pub use assistant::AssistantPanel; pub use assistant::AssistantPanel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs; use fs::Fs;
use futures::StreamExt; use futures::StreamExt;
use gpui::AppContext; use gpui::AppContext;
@ -12,7 +14,6 @@ use std::{
fmt::{self, Display}, fmt::{self, Display},
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
time::SystemTime,
}; };
use util::paths::CONVERSATIONS_DIR; use util::paths::CONVERSATIONS_DIR;
@ -24,11 +25,44 @@ struct OpenAIRequest {
stream: bool, stream: bool,
} }
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct SavedConversation { struct SavedConversation {
zed: String, zed: String,
version: String, version: String,
messages: Vec<RequestMessage>, text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
model: String,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
} }
struct SavedConversationMetadata { struct SavedConversationMetadata {

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings}, assistant_settings::{AssistantDockPosition, AssistantSettings},
OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role, SavedConversation, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
SavedConversationMetadata, RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
@ -27,10 +27,18 @@ use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset a
use serde::Deserialize; use serde::Deserialize;
use settings::SettingsStore; use settings::SettingsStore;
use std::{ use std::{
borrow::Cow, cell::RefCell, cmp, env, fmt::Write, io, iter, ops::Range, path::PathBuf, rc::Rc, borrow::Cow,
sync::Arc, time::Duration, cell::RefCell,
cmp, env,
fmt::Write,
io, iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::Duration,
}; };
use theme::{ui::IconStyle, IconButton, Theme}; use theme::ui::IconStyle;
use util::{ use util::{
channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt, channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt,
TryFutureExt, TryFutureExt,
@ -68,7 +76,7 @@ pub fn init(cx: &mut AppContext) {
|workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| { |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
if let Some(this) = workspace.panel::<AssistantPanel>(cx) { if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.add_conversation(cx); this.new_conversation(cx);
}) })
} }
@ -187,13 +195,8 @@ impl AssistantPanel {
}) })
} }
fn add_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> { fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let focus = self.has_focus(cx);
let editor = cx.add_view(|cx| { let editor = cx.add_view(|cx| {
if focus {
cx.focus_self();
}
ConversationEditor::new( ConversationEditor::new(
self.api_key.clone(), self.api_key.clone(),
self.languages.clone(), self.languages.clone(),
@ -201,14 +204,24 @@ impl AssistantPanel {
cx, cx,
) )
}); });
self.add_conversation(editor.clone(), cx);
editor
}
fn add_conversation(
&mut self,
editor: ViewHandle<ConversationEditor>,
cx: &mut ViewContext<Self>,
) {
self.subscriptions self.subscriptions
.push(cx.subscribe(&editor, Self::handle_conversation_editor_event)); .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
self.active_conversation_index = Some(self.conversation_editors.len()); self.active_conversation_index = Some(self.conversation_editors.len());
self.conversation_editors.push(editor.clone()); self.conversation_editors.push(editor.clone());
if self.has_focus(cx) {
cx.focus(&editor);
}
cx.notify(); cx.notify();
editor
} }
fn handle_conversation_editor_event( fn handle_conversation_editor_event(
@ -264,9 +277,28 @@ impl AssistantPanel {
} }
fn render_hamburger_button(style: &IconStyle) -> impl Element<Self> { fn render_hamburger_button(style: &IconStyle) -> impl Element<Self> {
enum ListConversations {}
Svg::for_style(style.icon.clone()) Svg::for_style(style.icon.clone())
.contained() .contained()
.with_style(style.container) .with_style(style.container)
.mouse::<ListConversations>(0)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, |_, this: &mut Self, cx| {
this.active_conversation_index = None;
cx.notify();
})
}
fn render_plus_button(style: &IconStyle) -> impl Element<Self> {
enum AddConversation {}
Svg::for_style(style.icon.clone())
.contained()
.with_style(style.container)
.mouse::<AddConversation>(0)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, |_, this: &mut Self, cx| {
this.new_conversation(cx);
})
} }
fn render_saved_conversation( fn render_saved_conversation(
@ -274,20 +306,23 @@ impl AssistantPanel {
index: usize, index: usize,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> impl Element<Self> { ) -> impl Element<Self> {
let conversation = &self.saved_conversations[index];
let path = conversation.path.clone();
MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| { MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
let style = &theme::current(cx).assistant.saved_conversation; let style = &theme::current(cx).assistant.saved_conversation;
let conversation = &self.saved_conversations[index];
Flex::row() Flex::row()
.with_child( .with_child(
Label::new( Label::new(
conversation.mtime.format("%c").to_string(), conversation.mtime.format("%F %I:%M%p").to_string(),
style.saved_at.text.clone(), style.saved_at.text.clone(),
) )
.aligned()
.contained() .contained()
.with_style(style.saved_at.container), .with_style(style.saved_at.container),
) )
.with_child( .with_child(
Label::new(conversation.title.clone(), style.title.text.clone()) Label::new(conversation.title.clone(), style.title.text.clone())
.aligned()
.contained() .contained()
.with_style(style.title.container), .with_style(style.title.container),
) )
@ -295,7 +330,48 @@ impl AssistantPanel {
.with_style(*style.container.style_for(state, false)) .with_style(*style.container.style_for(state, false))
}) })
.with_cursor_style(CursorStyle::PointingHand) .with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, |_, this, cx| {}) .on_click(MouseButton::Left, move |_, this, cx| {
this.open_conversation(path.clone(), cx)
.detach_and_log_err(cx)
})
}
fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
if let Some(ix) = self.conversation_editor_index_for_path(&path, cx) {
self.active_conversation_index = Some(ix);
cx.notify();
return Task::ready(Ok(()));
}
let fs = self.fs.clone();
let conversation = Conversation::load(
path.clone(),
self.api_key.clone(),
self.languages.clone(),
self.fs.clone(),
cx,
);
cx.spawn(|this, mut cx| async move {
let conversation = conversation.await?;
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
// the same conversation, we don't want to open it again.
if let Some(ix) = this.conversation_editor_index_for_path(&path, cx) {
this.active_conversation_index = Some(ix);
} else {
let editor = cx
.add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
this.add_conversation(editor, cx);
}
})?;
Ok(())
})
}
fn conversation_editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
self.conversation_editors
.iter()
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
} }
} }
@ -341,18 +417,22 @@ impl View for AssistantPanel {
.with_style(style.api_key_prompt.container) .with_style(style.api_key_prompt.container)
.aligned() .aligned()
.into_any() .into_any()
} else if let Some(editor) = self.active_conversation_editor() { } else {
Flex::column() Flex::column()
.with_child( .with_child(
Flex::row() Flex::row()
.with_child(Self::render_hamburger_button(&style.hamburger_button)) .with_child(
Self::render_hamburger_button(&style.hamburger_button).aligned(),
)
.with_child(Self::render_plus_button(&style.plus_button).aligned())
.contained() .contained()
.with_style(theme.workspace.tab_bar.container) .with_style(theme.workspace.tab_bar.container)
.expanded()
.constrained() .constrained()
.with_height(theme.workspace.tab_bar.height), .with_height(theme.workspace.tab_bar.height),
) )
.with_child(ChildView::new(editor, cx).flex(1., true)) .with_child(if let Some(editor) = self.active_conversation_editor() {
.into_any() ChildView::new(editor, cx).flex(1., true).into_any()
} else { } else {
UniformList::new( UniformList::new(
self.saved_conversations_list_state.clone(), self.saved_conversations_list_state.clone(),
@ -364,6 +444,9 @@ impl View for AssistantPanel {
} }
}, },
) )
.flex(1., true)
.into_any()
})
.into_any() .into_any()
} }
} }
@ -468,7 +551,7 @@ impl Panel for AssistantPanel {
} }
if self.conversation_editors.is_empty() { if self.conversation_editors.is_empty() {
self.add_conversation(cx); self.new_conversation(cx);
} }
} }
} }
@ -598,6 +681,74 @@ impl Conversation {
this this
} }
fn load(
path: PathBuf,
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
cx: &mut AppContext,
) -> Task<Result<ModelHandle<Self>>> {
cx.spawn(|mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
let model = saved_conversation.model;
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
let buffer = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, saved_conversation.text, cx);
for message in saved_conversation.messages {
message_anchors.push(MessageAnchor {
id: message.id,
start: buffer.anchor_before(message.start),
});
next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
}
buffer.set_language_registry(language_registry);
cx.spawn_weak(|buffer, mut cx| async move {
let markdown = markdown.await?;
let buffer = buffer
.upgrade(&cx)
.ok_or_else(|| anyhow!("buffer was dropped"))?;
buffer.update(&mut cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx)
});
anyhow::Ok(())
})
.detach_and_log_err(cx);
buffer
});
let conversation = cx.add_model(|cx| {
let mut this = Self {
message_anchors,
messages_metadata: saved_conversation.message_metadata,
next_message_id,
summary: Some(Summary {
text: saved_conversation.summary,
done: true,
}),
pending_summary: Task::ready(None),
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
max_token_count: tiktoken_rs::model::get_context_size(&model),
pending_token_count: Task::ready(None),
model,
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
api_key,
buffer,
};
this.count_remaining_tokens(cx);
this
});
Ok(conversation)
})
}
fn handle_buffer_event( fn handle_buffer_event(
&mut self, &mut self,
_: ModelHandle<Buffer>, _: ModelHandle<Buffer>,
@ -1122,15 +1273,22 @@ impl Conversation {
}); });
if let Some(summary) = summary { if let Some(summary) = summary {
let conversation = SavedConversation { let conversation = this.read_with(&cx, |this, cx| SavedConversation {
zed: "conversation".into(), zed: "conversation".into(),
version: "0.1".into(), version: SavedConversation::VERSION.into(),
messages: this.read_with(&cx, |this, cx| { text: this.buffer.read(cx).text(),
this.messages(cx) message_metadata: this.messages_metadata.clone(),
.map(|message| message.to_open_ai_message(this.buffer.read(cx))) messages: this
.collect() .message_anchors
}), .iter()
}; .map(|message| SavedMessage {
id: message.id,
start: message.start.to_offset(this.buffer.read(cx)),
})
.collect(),
summary: summary.clone(),
model: this.model.clone(),
});
let path = if let Some(old_path) = old_path { let path = if let Some(old_path) = old_path {
old_path old_path
@ -1195,6 +1353,14 @@ impl ConversationEditor {
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
Self::from_conversation(conversation, fs, cx)
}
fn from_conversation(
conversation: ModelHandle<Conversation>,
fs: Arc<dyn Fs>,
cx: &mut ViewContext<Self>,
) -> Self {
let editor = cx.add_view(|cx| { let editor = cx.add_view(|cx| {
let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx); let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx); editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
@ -1524,7 +1690,7 @@ impl ConversationEditor {
let conversation = panel let conversation = panel
.active_conversation_editor() .active_conversation_editor()
.cloned() .cloned()
.unwrap_or_else(|| panel.add_conversation(cx)); .unwrap_or_else(|| panel.new_conversation(cx));
conversation.update(cx, |conversation, cx| { conversation.update(cx, |conversation, cx| {
conversation conversation
.editor .editor
@ -1693,29 +1859,12 @@ impl Item for ConversationEditor {
} }
} }
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct MessageAnchor { struct MessageAnchor {
id: MessageId, id: MessageId,
start: language::Anchor, start: language::Anchor,
} }
#[derive(Clone, Debug)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Message { pub struct Message {
range: Range<usize>, range: Range<usize>,
@ -1733,7 +1882,7 @@ impl Message {
content.extend(buffer.text_for_range(self.range.clone())); content.extend(buffer.text_for_range(self.range.clone()));
RequestMessage { RequestMessage {
role: self.role, role: self.role,
content, content: content.trim_end().into(),
} }
} }
} }
@ -1826,6 +1975,8 @@ async fn stream_completion(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::MessageId;
use super::*; use super::*;
use fs::FakeFs; use fs::FakeFs;
use gpui::{AppContext, TestAppContext}; use gpui::{AppContext, TestAppContext};

View file

@ -995,6 +995,7 @@ pub struct TerminalStyle {
pub struct AssistantStyle { pub struct AssistantStyle {
pub container: ContainerStyle, pub container: ContainerStyle,
pub hamburger_button: IconStyle, pub hamburger_button: IconStyle,
pub plus_button: IconStyle,
pub message_header: ContainerStyle, pub message_header: ContainerStyle,
pub sent_at: ContainedText, pub sent_at: ContainedText,
pub user_sender: Interactive<ContainedText>, pub user_sender: Interactive<ContainedText>,

View file

@ -23,7 +23,36 @@ export default function assistant(colorScheme: ColorScheme) {
height: 15, height: 15,
}, },
}, },
container: {} container: {
margin: { left: 8 },
}
},
plusButton: {
icon: {
color: text(layer, "sans", "default", { size: "sm" }).color,
asset: "icons/plus_12.svg",
dimensions: {
width: 12,
height: 12,
},
},
container: {
margin: { left: 8 },
}
},
savedConversation: {
background: background(layer, "on"),
hover: {
background: background(layer, "on", "hovered"),
},
savedAt: {
margin: { left: 8 },
...text(layer, "sans", "default", { size: "xs" }),
},
title: {
margin: { left: 8 },
...text(layer, "sans", "default", { size: "sm", weight: "bold" }),
}
}, },
userSender: { userSender: {
...text(layer, "sans", "default", { size: "sm", weight: "bold" }), ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),