assistant2: Render messages as Markdown (#21496)

This PR updates Assistant 2 to render the messages in the thread as
Markdown:

<img width="1138" alt="Screenshot 2024-12-03 at 6 09 27 PM"
src="https://github.com/user-attachments/assets/c1c44fde-1efb-43cf-b9c9-768e6974c753">

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-12-03 18:32:13 -05:00 committed by GitHub
parent ecaf44511c
commit 9f459ba573
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 81 additions and 4 deletions

2
Cargo.lock generated
View file

@ -464,10 +464,12 @@ dependencies = [
"feature_flags", "feature_flags",
"futures 0.3.31", "futures 0.3.31",
"gpui", "gpui",
"language",
"language_model", "language_model",
"language_model_selector", "language_model_selector",
"language_models", "language_models",
"log", "log",
"markdown",
"project", "project",
"proto", "proto",
"serde", "serde",

View file

@ -23,15 +23,17 @@ editor.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
futures.workspace = true futures.workspace = true
gpui.workspace = true gpui.workspace = true
language.workspace = true
language_model.workspace = true language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true
language_models.workspace = true language_models.workspace = true
log.workspace = true log.workspace = true
markdown.workspace = true
project.workspace = true project.workspace = true
proto.workspace = true proto.workspace = true
settings.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
settings.workspace = true
smol.workspace = true smol.workspace = true
theme.workspace = true theme.workspace = true
ui.workspace = true ui.workspace = true

View file

@ -3,13 +3,19 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use client::zed_urls; use client::zed_urls;
use collections::HashMap;
use gpui::{ use gpui::{
list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter, list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription, FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels,
Task, View, ViewContext, WeakView, WindowContext, StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
WindowContext,
}; };
use language::LanguageRegistry;
use language_model::{LanguageModelRegistry, Role}; use language_model::{LanguageModelRegistry, Role};
use language_model_selector::LanguageModelSelector; use language_model_selector::LanguageModelSelector;
use markdown::{Markdown, MarkdownStyle};
use settings::Settings;
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip}; use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace; use workspace::Workspace;
@ -32,10 +38,12 @@ pub fn init(cx: &mut AppContext) {
pub struct AssistantPanel { pub struct AssistantPanel {
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
language_registry: Arc<LanguageRegistry>,
#[allow(unused)] #[allow(unused)]
thread_store: Model<ThreadStore>, thread_store: Model<ThreadStore>,
thread: Model<Thread>, thread: Model<Thread>,
thread_messages: Vec<MessageId>, thread_messages: Vec<MessageId>,
rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
thread_list_state: ListState, thread_list_state: ListState,
message_editor: View<MessageEditor>, message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
@ -77,9 +85,11 @@ impl AssistantPanel {
Self { Self {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
language_registry: workspace.project().read(cx).languages().clone(),
thread_store, thread_store,
thread: thread.clone(), thread: thread.clone(),
thread_messages: Vec::new(), thread_messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.view().downgrade(); let this = cx.view().downgrade();
move |ix, cx: &mut WindowContext| { move |ix, cx: &mut WindowContext| {
@ -104,6 +114,9 @@ impl AssistantPanel {
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx)); self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread; self.thread = thread;
self.thread_messages.clear();
self.thread_list_state.reset(0);
self.rendered_messages_by_id.clear();
self._subscriptions = subscriptions; self._subscriptions = subscriptions;
self.message_editor.focus_handle(cx).focus(cx); self.message_editor.focus_handle(cx).focus(cx);
@ -120,10 +133,61 @@ impl AssistantPanel {
self.last_error = Some(error.clone()); self.last_error = Some(error.clone());
} }
ThreadEvent::StreamedCompletion => {} ThreadEvent::StreamedCompletion => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
markdown.append(text, cx);
});
}
}
ThreadEvent::MessageAdded(message_id) => { ThreadEvent::MessageAdded(message_id) => {
let old_len = self.thread_messages.len(); let old_len = self.thread_messages.len();
self.thread_messages.push(*message_id); self.thread_messages.push(*message_id);
self.thread_list_state.splice(old_len..old_len, 1); self.thread_list_state.splice(old_len..old_len, 1);
if let Some(message_text) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
{
let theme_settings = ThemeSettings::get_global(cx);
let mut text_style = cx.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_size: Some(TextSize::Default.rems(cx).into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(theme_settings.buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
..Default::default()
};
let markdown = cx.new_view(|cx| {
Markdown::new(
message_text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
});
self.rendered_messages_by_id.insert(*message_id, markdown);
}
cx.notify(); cx.notify();
} }
ThreadEvent::UsePendingTools => { ThreadEvent::UsePendingTools => {
@ -323,6 +387,10 @@ impl AssistantPanel {
return Empty.into_any(); return Empty.into_any();
}; };
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
let (role_icon, role_name) = match message.role { let (role_icon, role_name) = match message.role {
Role::User => (IconName::Person, "You"), Role::User => (IconName::Person, "You"),
Role::Assistant => (IconName::ZedAssistant, "Assistant"), Role::Assistant => (IconName::ZedAssistant, "Assistant"),
@ -350,7 +418,7 @@ impl AssistantPanel {
.child(Label::new(role_name).size(LabelSize::Small)), .child(Label::new(role_name).size(LabelSize::Small)),
), ),
) )
.child(v_flex().p_1p5().child(Label::new(message.text.clone()))), .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
) )
.into_any() .into_any()
} }

View file

@ -167,6 +167,10 @@ impl Thread {
if let Some(last_message) = thread.messages.last_mut() { if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant { if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk); last_message.text.push_str(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
));
} }
} }
} }
@ -320,6 +324,7 @@ pub enum ThreadError {
pub enum ThreadEvent { pub enum ThreadEvent {
ShowError(ThreadError), ShowError(ThreadError),
StreamedCompletion, StreamedCompletion,
StreamedAssistantText(MessageId, String),
MessageAdded(MessageId), MessageAdded(MessageId),
UsePendingTools, UsePendingTools,
ToolFinished { ToolFinished {