Show custom header for assistant messages

This commit is contained in:
Antonio Scandurra 2023-05-29 15:57:55 +02:00
parent 404bebab63
commit 52e8bf2928
5 changed files with 313 additions and 171 deletions

View file

@ -1,13 +1,15 @@
use crate::{OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role};
use anyhow::{anyhow, Result};
use editor::{Editor, MultiBuffer};
use collections::HashMap;
use editor::{Editor, ExcerptId, ExcerptRange, MultiBuffer};
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use gpui::{
actions, elements::*, executor::Background, Action, AppContext, AsyncAppContext, Entity,
ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
WindowContext,
};
use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Anchor, Buffer, Language, LanguageRegistry};
use language::{language_settings::SoftWrap, Buffer, Language, LanguageRegistry};
use std::{io, sync::Arc};
use util::{post_inc, ResultExt, TryFutureExt};
use workspace::{
@ -19,8 +21,8 @@ use workspace::{
actions!(assistant, [NewContext, Assist, CancelLastAssist]);
pub fn init(cx: &mut AppContext) {
cx.add_action(Assistant::assist);
cx.capture_action(Assistant::cancel_last_assist);
cx.add_action(AssistantEditor::assist);
cx.capture_action(AssistantEditor::cancel_last_assist);
}
pub enum AssistantPanelEvent {
@ -188,7 +190,7 @@ impl Panel for AssistantPanel {
.await?;
workspace.update(&mut cx, |workspace, cx| {
let editor = Box::new(cx.add_view(|cx| {
Assistant::new(markdown, workspace.app_state().languages.clone(), cx)
AssistantEditor::new(markdown, workspace.app_state().languages.clone(), cx)
}));
Pane::add_item(workspace, &pane, editor, true, focus, None, cx);
})?;
@ -230,38 +232,31 @@ impl Panel for AssistantPanel {
}
struct Assistant {
buffer: ModelHandle<MultiBuffer>,
messages: Vec<Message>,
editor: ViewHandle<Editor>,
messages_by_id: HashMap<ExcerptId, Message>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
markdown: Arc<Language>,
language_registry: Arc<LanguageRegistry>,
}
struct PendingCompletion {
id: usize,
_task: Task<Option<()>>,
impl Entity for Assistant {
type Event = ();
}
impl Assistant {
fn new(
markdown: Arc<Language>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ViewContext<Self>,
cx: &mut ModelContext<Self>,
) -> Self {
let editor = cx.add_view(|cx| {
let multibuffer = cx.add_model(|_| MultiBuffer::new(0));
let mut editor = Editor::for_multibuffer(multibuffer, None, cx);
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor.set_show_gutter(false, cx);
editor
});
let mut this = Self {
buffer: cx.add_model(|_| MultiBuffer::new(0)),
messages: Default::default(),
editor,
completion_count: 0,
pending_completions: Vec::new(),
messages_by_id: Default::default(),
completion_count: Default::default(),
pending_completions: Default::default(),
markdown,
language_registry,
};
@ -269,7 +264,7 @@ impl Assistant {
this
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
fn assist(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
.messages
.iter()
@ -285,8 +280,8 @@ impl Assistant {
};
if let Some(api_key) = std::env::var("OPENAI_API_KEY").log_err() {
let stream = stream_completion(api_key, cx.background_executor().clone(), request);
let response_buffer = self.push_message(Role::Assistant, cx);
let stream = stream_completion(api_key, cx.background().clone(), request);
let response = self.push_message(Role::Assistant, cx);
self.push_message(Role::User, cx);
let task = cx.spawn(|this, mut cx| {
async move {
@ -295,7 +290,7 @@ impl Assistant {
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
response_buffer.update(&mut cx, |content, cx| {
response.content.update(&mut cx, |content, cx| {
let text: Arc<str> = choice.delta.content?.into();
content.edit([(content.len()..content.len(), text)], None, cx);
Some(())
@ -306,8 +301,7 @@ impl Assistant {
this.update(&mut cx, |this, _| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
})
.ok();
});
anyhow::Ok(())
}
@ -321,45 +315,123 @@ impl Assistant {
}
}
fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
if self.pending_completions.pop().is_none() {
cx.propagate_action();
}
fn cancel_last_assist(&mut self) -> bool {
self.pending_completions.pop().is_some()
}
fn push_message(&mut self, role: Role, cx: &mut ViewContext<Self>) -> ModelHandle<Buffer> {
fn push_message(&mut self, role: Role, cx: &mut ModelContext<Self>) -> Message {
let content = cx.add_model(|cx| {
let mut buffer = Buffer::new(0, "", cx);
buffer.set_language(Some(self.markdown.clone()), cx);
buffer.set_language_registry(self.language_registry.clone());
buffer
});
let excerpt_id = self.buffer.update(cx, |buffer, cx| {
buffer
.push_excerpts(
content.clone(),
vec![ExcerptRange {
context: 0..0,
primary: None,
}],
cx,
)
.pop()
.unwrap()
});
let message = Message {
role,
content: content.clone(),
};
self.messages.push(message);
self.editor.update(cx, |editor, cx| {
editor.buffer().update(cx, |buffer, cx| {
buffer.push_excerpts_with_context_lines(
content.clone(),
vec![Anchor::MIN..Anchor::MAX],
0,
cx,
)
});
});
content
self.messages.push(message.clone());
self.messages_by_id.insert(excerpt_id, message.clone());
message
}
}
impl Entity for Assistant {
struct PendingCompletion {
id: usize,
_task: Task<Option<()>>,
}
struct AssistantEditor {
assistant: ModelHandle<Assistant>,
editor: ViewHandle<Editor>,
}
impl AssistantEditor {
fn new(
markdown: Arc<Language>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ViewContext<Self>,
) -> Self {
let assistant = cx.add_model(|cx| Assistant::new(markdown, language_registry, cx));
let editor = cx.add_view(|cx| {
let mut editor = Editor::for_multibuffer(assistant.read(cx).buffer.clone(), None, cx);
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor.set_show_gutter(false, cx);
editor.set_render_excerpt_header(
{
let assistant = assistant.clone();
move |editor, params: editor::RenderExcerptHeaderParams, cx| {
let style = &theme::current(cx).assistant;
if let Some(message) = assistant.read(cx).messages_by_id.get(&params.id) {
let sender = match message.role {
Role::User => Label::new("You", style.user_sender.text.clone())
.contained()
.with_style(style.user_sender.container),
Role::Assistant => {
Label::new("Assistant", style.assistant_sender.text.clone())
.contained()
.with_style(style.assistant_sender.container)
}
Role::System => {
Label::new("System", style.assistant_sender.text.clone())
.contained()
.with_style(style.assistant_sender.container)
}
};
Flex::row()
.with_child(sender)
.aligned()
.left()
.contained()
.with_style(style.header)
.into_any()
} else {
Empty::new().into_any()
}
}
},
cx,
);
editor
});
Self { assistant, editor }
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
self.assistant
.update(cx, |assistant, cx| assistant.assist(cx));
}
fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
if !self
.assistant
.update(cx, |assistant, _| assistant.cancel_last_assist())
{
cx.propagate_action();
}
}
}
impl Entity for AssistantEditor {
type Event = ();
}
impl View for Assistant {
impl View for AssistantEditor {
fn ui_name() -> &'static str {
"ContextEditor"
}
@ -374,7 +446,7 @@ impl View for Assistant {
}
}
impl Item for Assistant {
impl Item for AssistantEditor {
fn tab_content<V: View>(
&self,
_: Option<usize>,
@ -385,6 +457,7 @@ impl Item for Assistant {
}
}
#[derive(Clone)]
struct Message {
role: Role,
content: ModelHandle<Buffer>,