assistant2: Factor out ActiveThread view (#21555)

This PR factors a new `ActiveThread` view out of the `AssistantPanel` to
group together the state that pertains solely to the active view.

There was a bunch of related state on the `AssistantPanel` pertaining to
the active thread that needed to be initialized/reset together and it
makes for a clearer narrative is this state is encapsulated in its own
view.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-12-04 16:39:39 -05:00 committed by GitHub
parent 55ecb3c51b
commit a30ea2fc68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 396 additions and 291 deletions

View file

@ -0,0 +1,237 @@
use std::sync::Arc;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
use gpui::{
list, AnyElement, Empty, ListAlignment, ListState, Model, StyleRefinement, Subscription,
TextStyleRefinement, View, WeakView,
};
use language::LanguageRegistry;
use language_model::Role;
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
use theme::ThemeSettings;
use ui::prelude::*;
use workspace::Workspace;
use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
pub struct ActiveThread {
workspace: WeakView<Workspace>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
thread: Model<Thread>,
messages: Vec<MessageId>,
list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
impl ActiveThread {
pub fn new(
thread: Model<Thread>,
workspace: WeakView<Workspace>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
cx: &mut ViewContext<Self>,
) -> Self {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
let mut this = Self {
workspace,
language_registry,
tools,
thread: thread.clone(),
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.view().downgrade();
move |ix, cx: &mut WindowContext| {
this.update(cx, |this, cx| this.render_message(ix, cx))
.unwrap()
}
}),
last_error: None,
_subscriptions: subscriptions,
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
this.push_message(&message.id, message.text.clone(), cx);
}
this
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn last_error(&self) -> Option<ThreadError> {
self.last_error.clone()
}
pub fn clear_last_error(&mut self) {
self.last_error.take();
}
fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
let old_len = self.messages.len();
self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1);
let theme_settings = ThemeSettings::get_global(cx);
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = theme_settings.buffer_font_size;
let mut text_style = cx.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(ui_font_size.into()),
background_color: Some(cx.theme().colors().editor_background),
..Default::default()
},
..Default::default()
};
let markdown = cx.new_view(|cx| {
Markdown::new(
text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
});
self.rendered_messages_by_id.insert(*id, markdown);
}
fn handle_thread_event(
&mut self,
_: Model<Thread>,
event: &ThreadEvent,
cx: &mut ViewContext<Self>,
) {
match event {
ThreadEvent::ShowError(error) => {
self.last_error = Some(error.clone());
}
ThreadEvent::StreamedCompletion => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
markdown.append(text, cx);
});
}
}
ThreadEvent::MessageAdded(message_id) => {
if let Some(message_text) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
{
self.push_message(message_id, message_text, cx);
}
cx.notify();
}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(
tool_use.assistant_message_id,
tool_use.id.clone(),
task,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished { .. } => {}
}
}
fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
let message_id = self.messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else {
return Empty.into_any();
};
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
let (role_icon, role_name) = match message.role {
Role::User => (IconName::Person, "You"),
Role::Assistant => (IconName::ZedAssistant, "Assistant"),
Role::System => (IconName::Settings, "System"),
};
div()
.id(("message-container", ix))
.p_2()
.child(
v_flex()
.border_1()
.border_color(cx.theme().colors().border_variant)
.rounded_md()
.child(
h_flex()
.justify_between()
.p_1p5()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.child(
h_flex()
.gap_2()
.child(Icon::new(role_icon).size(IconSize::Small))
.child(Label::new(role_name).size(LabelSize::Small)),
),
)
.child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
)
.into_any()
}
}
impl Render for ActiveThread {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
list(self.list_state.clone()).flex_1()
}
}

View file

@ -1,3 +1,4 @@
mod active_thread;
mod assistant_panel; mod assistant_panel;
mod message_editor; mod message_editor;
mod thread; mod thread;

View file

@ -3,25 +3,21 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use assistant_tool::ToolWorkingSet; use assistant_tool::ToolWorkingSet;
use client::zed_urls; use client::zed_urls;
use collections::HashMap;
use gpui::{ use gpui::{
list, prelude::*, px, svg, Action, AnyElement, AppContext, AsyncWindowContext, Empty, prelude::*, px, svg, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter,
EventEmitter, FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, FocusHandle, FocusableView, FontWeight, Model, Pixels, Task, View, ViewContext, WeakView,
StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
WindowContext, WindowContext,
}; };
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{LanguageModelRegistry, Role}; use language_model::LanguageModelRegistry;
use language_model_selector::LanguageModelSelector; use language_model_selector::LanguageModelSelector;
use markdown::{Markdown, MarkdownStyle};
use settings::Settings;
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, KeyBinding, ListItem, Tab, Tooltip}; use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, KeyBinding, ListItem, Tab, Tooltip};
use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace; use workspace::Workspace;
use crate::active_thread::ActiveThread;
use crate::message_editor::MessageEditor; use crate::message_editor::MessageEditor;
use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ThreadId}; use crate::thread::{Thread, ThreadError, ThreadId};
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
use crate::{NewThread, OpenHistory, ToggleFocus, ToggleModelSelector}; use crate::{NewThread, OpenHistory, ToggleFocus, ToggleModelSelector};
@ -39,16 +35,10 @@ pub fn init(cx: &mut AppContext) {
pub struct AssistantPanel { pub struct AssistantPanel {
workspace: WeakView<Workspace>, workspace: WeakView<Workspace>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
#[allow(unused)]
thread_store: Model<ThreadStore>, thread_store: Model<ThreadStore>,
thread: Model<Thread>, thread: Option<View<ActiveThread>>,
thread_messages: Vec<MessageId>,
rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
thread_list_state: ListState,
message_editor: View<MessageEditor>, message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>, tools: Arc<ToolWorkingSet>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
} }
impl AssistantPanel { impl AssistantPanel {
@ -78,29 +68,14 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
Self { Self {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
language_registry: workspace.project().read(cx).languages().clone(), language_registry: workspace.project().read(cx).languages().clone(),
thread_store, thread_store,
thread: thread.clone(), thread: None,
thread_messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.view().downgrade();
move |ix, cx: &mut WindowContext| {
this.update(cx, |this, cx| this.render_message(ix, cx))
.unwrap()
}
}),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools, tools,
last_error: None,
_subscriptions: subscriptions,
} }
} }
@ -108,7 +83,18 @@ impl AssistantPanel {
let thread = self let thread = self
.thread_store .thread_store
.update(cx, |this, cx| this.create_thread(cx)); .update(cx, |this, cx| this.create_thread(cx));
self.reset_thread(thread, cx);
self.thread = Some(cx.new_view(|cx| {
ActiveThread::new(
thread.clone(),
self.workspace.clone(),
self.language_registry.clone(),
self.tools.clone(),
cx,
)
}));
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread, cx));
self.message_editor.focus_handle(cx).focus(cx);
} }
fn open_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext<Self>) { fn open_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext<Self>) {
@ -118,136 +104,18 @@ impl AssistantPanel {
else { else {
return; return;
}; };
self.reset_thread(thread.clone(), cx);
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() { self.thread = Some(cx.new_view(|cx| {
self.push_message(&message.id, message.text.clone(), cx); ActiveThread::new(
} thread.clone(),
} self.workspace.clone(),
self.language_registry.clone(),
fn reset_thread(&mut self, thread: Model<Thread>, cx: &mut ViewContext<Self>) { self.tools.clone(),
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
];
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread;
self.thread_messages.clear();
self.thread_list_state.reset(0);
self.rendered_messages_by_id.clear();
self._subscriptions = subscriptions;
self.message_editor.focus_handle(cx).focus(cx);
}
fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
let old_len = self.thread_messages.len();
self.thread_messages.push(*id);
self.thread_list_state.splice(old_len..old_len, 1);
let theme_settings = ThemeSettings::get_global(cx);
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = theme_settings.buffer_font_size;
let mut text_style = cx.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(ui_font_size.into()),
background_color: Some(cx.theme().colors().editor_background),
..Default::default()
},
..Default::default()
};
let markdown = cx.new_view(|cx| {
Markdown::new(
text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx, cx,
) )
}); }));
self.rendered_messages_by_id.insert(*id, markdown); self.message_editor = cx.new_view(|cx| MessageEditor::new(thread, cx));
} self.message_editor.focus_handle(cx).focus(cx);
fn handle_thread_event(
&mut self,
_: Model<Thread>,
event: &ThreadEvent,
cx: &mut ViewContext<Self>,
) {
match event {
ThreadEvent::ShowError(error) => {
self.last_error = Some(error.clone());
}
ThreadEvent::StreamedCompletion => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
markdown.append(text, cx);
});
}
}
ThreadEvent::MessageAdded(message_id) => {
if let Some(message_text) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
{
self.push_message(message_id, message_text, cx);
}
cx.notify();
}
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
.read(cx)
.pending_tool_uses()
.into_iter()
.filter(|tool_use| tool_use.status.is_idle())
.cloned()
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(
tool_use.assistant_message_id,
tool_use.id.clone(),
task,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished { .. } => {}
}
} }
} }
@ -422,13 +290,24 @@ impl AssistantPanel {
) )
} }
fn render_message_list(&self, cx: &mut ViewContext<Self>) -> AnyElement { fn render_active_thread_or_empty_state(&self, cx: &mut ViewContext<Self>) -> AnyElement {
if self.thread_messages.is_empty() { let Some(thread) = self.thread.as_ref() else {
return self.render_thread_empty_state(cx).into_any_element();
};
if thread.read(cx).is_empty() {
return self.render_thread_empty_state(cx).into_any_element();
}
thread.clone().into_any()
}
fn render_thread_empty_state(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let recent_threads = self let recent_threads = self
.thread_store .thread_store
.update(cx, |this, cx| this.recent_threads(3, cx)); .update(cx, |this, cx| this.recent_threads(3, cx));
return v_flex() v_flex()
.gap_2() .gap_2()
.mx_auto() .mx_auto()
.child( .child(
@ -510,52 +389,6 @@ impl AssistantPanel {
}), }),
), ),
) )
.into_any();
}
list(self.thread_list_state.clone()).flex_1().into_any()
}
fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
let message_id = self.thread_messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else {
return Empty.into_any();
};
let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
let (role_icon, role_name) = match message.role {
Role::User => (IconName::Person, "You"),
Role::Assistant => (IconName::ZedAssistant, "Assistant"),
Role::System => (IconName::Settings, "System"),
};
div()
.id(("message-container", ix))
.p_2()
.child(
v_flex()
.border_1()
.border_color(cx.theme().colors().border_variant)
.rounded_md()
.child(
h_flex()
.justify_between()
.p_1p5()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.child(
h_flex()
.gap_2()
.child(Icon::new(role_icon).size(IconSize::Small))
.child(Label::new(role_name).size(LabelSize::Small)),
),
)
.child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
)
.into_any()
} }
fn render_past_thread( fn render_past_thread(
@ -584,7 +417,7 @@ impl AssistantPanel {
} }
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> { fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let last_error = self.last_error.as_ref()?; let last_error = self.thread.as_ref()?.read(cx).last_error()?;
Some( Some(
div() div()
@ -602,7 +435,7 @@ impl AssistantPanel {
self.render_max_monthly_spend_reached_error(cx) self.render_max_monthly_spend_reached_error(cx)
} }
ThreadError::Message(error_message) => { ThreadError::Message(error_message) => {
self.render_error_message(error_message, cx) self.render_error_message(&error_message, cx)
} }
}) })
.into_any(), .into_any(),
@ -634,14 +467,24 @@ impl AssistantPanel {
.mt_1() .mt_1()
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener( .child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|this, _, cx| { |this, _, cx| {
this.last_error = None; if let Some(thread) = this.thread.as_ref() {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
}
cx.open_url(&zed_urls::account_url(cx)); cx.open_url(&zed_urls::account_url(cx));
cx.notify(); cx.notify();
}, },
))) )))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener( .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| { |this, _, cx| {
this.last_error = None; if let Some(thread) = this.thread.as_ref() {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
}
cx.notify(); cx.notify();
}, },
))), ))),
@ -675,7 +518,12 @@ impl AssistantPanel {
.child( .child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click( Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, cx| { cx.listener(|this, _, cx| {
this.last_error = None; if let Some(thread) = this.thread.as_ref() {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
}
cx.open_url(&zed_urls::account_url(cx)); cx.open_url(&zed_urls::account_url(cx));
cx.notify(); cx.notify();
}), }),
@ -683,7 +531,12 @@ impl AssistantPanel {
) )
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener( .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| { |this, _, cx| {
this.last_error = None; if let Some(thread) = this.thread.as_ref() {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
}
cx.notify(); cx.notify();
}, },
))), ))),
@ -721,7 +574,12 @@ impl AssistantPanel {
.mt_1() .mt_1()
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener( .child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, cx| { |this, _, cx| {
this.last_error = None; if let Some(thread) = this.thread.as_ref() {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
}
cx.notify(); cx.notify();
}, },
))), ))),
@ -743,7 +601,7 @@ impl Render for AssistantPanel {
println!("Open History"); println!("Open History");
})) }))
.child(self.render_toolbar(cx)) .child(self.render_toolbar(cx))
.child(self.render_message_list(cx)) .child(self.render_active_thread_or_empty_state(cx))
.child( .child(
h_flex() h_flex()
.border_t_1() .border_t_1()

View file

@ -85,6 +85,10 @@ impl Thread {
&self.id &self.id
} }
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn message(&self, id: MessageId) -> Option<&Message> { pub fn message(&self, id: MessageId) -> Option<&Message> {
self.messages.iter().find(|message| message.id == id) self.messages.iter().find(|message| message.id == id)
} }

View file

@ -52,8 +52,13 @@ impl ThreadStore {
}) })
} }
pub fn recent_threads(&self, limit: usize, _cx: &ModelContext<Self>) -> Vec<Model<Thread>> { pub fn recent_threads(&self, limit: usize, cx: &ModelContext<Self>) -> Vec<Model<Thread>> {
self.threads.iter().take(limit).cloned().collect() self.threads
.iter()
.filter(|thread| !thread.read(cx).is_empty())
.take(limit)
.cloned()
.collect()
} }
pub fn create_thread(&mut self, cx: &mut ModelContext<Self>) -> Model<Thread> { pub fn create_thread(&mut self, cx: &mut ModelContext<Self>) -> Model<Thread> {