Cycle message roles on click

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-06-06 18:45:08 +02:00
parent ac7178068f
commit ef7ec265c8
4 changed files with 106 additions and 51 deletions

View file

@ -34,6 +34,16 @@ enum Role {
System, System,
} }
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role { impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {

View file

@ -485,8 +485,9 @@ impl Assistant {
let messages = self let messages = self
.messages .messages
.iter() .iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage { .filter_map(|message| {
role: match message.role { Some(tiktoken_rs::ChatCompletionRequestMessage {
role: match self.messages_metadata.get(&message.excerpt_id)?.role {
Role::User => "user".into(), Role::User => "user".into(),
Role::Assistant => "assistant".into(), Role::Assistant => "assistant".into(),
Role::System => "system".into(), Role::System => "system".into(),
@ -494,6 +495,7 @@ impl Assistant {
content: message.content.read(cx).text(), content: message.content.read(cx).text(),
name: None, name: None,
}) })
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let model = self.model.clone(); let model = self.model.clone();
self.pending_token_count = cx.spawn(|this, mut cx| { self.pending_token_count = cx.spawn(|this, mut cx| {
@ -529,10 +531,12 @@ impl Assistant {
let messages = self let messages = self
.messages .messages
.iter() .iter()
.map(|message| RequestMessage { .filter_map(|message| {
role: message.role, Some(RequestMessage {
role: self.messages_metadata.get(&message.excerpt_id)?.role,
content: message.content.read(cx).text(), content: message.content.read(cx).text(),
}) })
})
.collect(); .collect();
let request = OpenAIRequest { let request = OpenAIRequest {
model: self.model.clone(), model: self.model.clone(),
@ -621,6 +625,13 @@ impl Assistant {
} }
} }
fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
metadata.role.cycle();
cx.notify();
}
}
fn push_message( fn push_message(
&mut self, &mut self,
role: Role, role: Role,
@ -659,7 +670,6 @@ impl Assistant {
self.messages.push(Message { self.messages.push(Message {
excerpt_id, excerpt_id,
role,
content: content.clone(), content: content.clone(),
}); });
self.messages_metadata.insert( self.messages_metadata.insert(
@ -681,10 +691,12 @@ impl Assistant {
.messages .messages
.iter() .iter()
.take(2) .take(2)
.map(|message| RequestMessage { .filter_map(|message| {
role: message.role, Some(RequestMessage {
role: self.messages_metadata.get(&message.excerpt_id)?.role,
content: message.content.read(cx).text(), content: message.content.read(cx).text(),
}) })
})
.chain(Some(RequestMessage { .chain(Some(RequestMessage {
role: Role::User, role: Role::User,
content: content:
@ -753,27 +765,51 @@ impl AssistantEditor {
{ {
let assistant = assistant.clone(); let assistant = assistant.clone();
move |_editor, params: editor::RenderExcerptHeaderParams, cx| { move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
enum Sender {}
enum ErrorTooltip {} enum ErrorTooltip {}
let theme = theme::current(cx); let theme = theme::current(cx);
let style = &theme.assistant; let style = &theme.assistant;
if let Some(metadata) = assistant.read(cx).messages_metadata.get(&params.id) let excerpt_id = params.id;
if let Some(metadata) = assistant
.read(cx)
.messages_metadata
.get(&excerpt_id)
.cloned()
{ {
let sender = match metadata.role { let sender = MouseEventHandler::<Sender, _>::new(
Role::User => Label::new("You", style.user_sender.text.clone()) params.id.into(),
cx,
|state, _| match metadata.role {
Role::User => {
let style = style.user_sender.style_for(state, false);
Label::new("You", style.text.clone())
.contained() .contained()
.with_style(style.user_sender.container), .with_style(style.container)
}
Role::Assistant => { Role::Assistant => {
Label::new("Assistant", style.assistant_sender.text.clone()) let style = style.assistant_sender.style_for(state, false);
Label::new("Assistant", style.text.clone())
.contained() .contained()
.with_style(style.assistant_sender.container) .with_style(style.container)
} }
Role::System => { Role::System => {
Label::new("System", style.assistant_sender.text.clone()) let style = style.system_sender.style_for(state, false);
Label::new("System", style.text.clone())
.contained() .contained()
.with_style(style.assistant_sender.container) .with_style(style.container)
} }
}; },
)
.with_cursor_style(CursorStyle::PointingHand)
.on_down(MouseButton::Left, {
let assistant = assistant.clone();
move |_, _, cx| {
assistant.update(cx, |assistant, cx| {
assistant.cycle_message_role(excerpt_id, cx)
})
}
});
Flex::row() Flex::row()
.with_child(sender.aligned()) .with_child(sender.aligned())
@ -786,7 +822,7 @@ impl AssistantEditor {
.with_style(style.sent_at.container) .with_style(style.sent_at.container)
.aligned(), .aligned(),
) )
.with_children(metadata.error.clone().map(|error| { .with_children(metadata.error.map(|error| {
Svg::new("icons/circle_x_mark_12.svg") Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color) .with_color(style.error_icon.color)
.constrained() .constrained()
@ -833,21 +869,22 @@ impl AssistantEditor {
self.assistant.update(cx, |assistant, cx| { self.assistant.update(cx, |assistant, cx| {
let editor = self.editor.read(cx); let editor = self.editor.read(cx);
let newest_selection = editor.selections.newest_anchor(); let newest_selection = editor.selections.newest_anchor();
let role = if newest_selection.head() == Anchor::min() { let excerpt_id = if newest_selection.head() == Anchor::min() {
assistant.messages.first().map(|message| message.role) assistant.messages.first().map(|message| message.excerpt_id)
} else if newest_selection.head() == Anchor::max() { } else if newest_selection.head() == Anchor::max() {
assistant.messages.last().map(|message| message.role) assistant.messages.last().map(|message| message.excerpt_id)
} else { } else {
assistant Some(newest_selection.head().excerpt_id())
.messages_metadata
.get(&newest_selection.head().excerpt_id())
.map(|message| message.role)
}; };
if role.map_or(false, |role| role == Role::Assistant) { if let Some(excerpt_id) = excerpt_id {
assistant.push_message(Role::User, cx); if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
} else { if metadata.role == Role::User {
assistant.assist(cx); assistant.assist(cx);
} else {
assistant.push_message(Role::User, cx);
}
}
} }
}); });
} }
@ -967,14 +1004,19 @@ impl AssistantEditor {
let range = cmp::max(message_range.start, selection.range().start) let range = cmp::max(message_range.start, selection.range().start)
..cmp::min(message_range.end, selection.range().end); ..cmp::min(message_range.end, selection.range().end);
if !range.is_empty() { if !range.is_empty() {
if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id)
{
spanned_messages += 1; spanned_messages += 1;
write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
for chunk in assistant.buffer.read(cx).snapshot(cx).text_for_range(range) { for chunk in
assistant.buffer.read(cx).snapshot(cx).text_for_range(range)
{
copied_text.push_str(&chunk); copied_text.push_str(&chunk);
} }
copied_text.push('\n'); copied_text.push('\n');
} }
} }
}
offset = message_range.end; offset = message_range.end;
} }
@ -1090,11 +1132,10 @@ impl Item for AssistantEditor {
#[derive(Debug)] #[derive(Debug)]
struct Message { struct Message {
excerpt_id: ExcerptId, excerpt_id: ExcerptId,
role: Role,
content: ModelHandle<Buffer>, content: ModelHandle<Buffer>,
} }
#[derive(Debug)] #[derive(Clone, Debug)]
struct MessageMetadata { struct MessageMetadata {
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,

View file

@ -974,8 +974,9 @@ pub struct AssistantStyle {
pub container: ContainerStyle, pub container: ContainerStyle,
pub header: ContainerStyle, pub header: ContainerStyle,
pub sent_at: ContainedText, pub sent_at: ContainedText,
pub user_sender: ContainedText, pub user_sender: Interactive<ContainedText>,
pub assistant_sender: ContainedText, pub assistant_sender: Interactive<ContainedText>,
pub system_sender: Interactive<ContainedText>,
pub model_info_container: ContainerStyle, pub model_info_container: ContainerStyle,
pub model: Interactive<ContainedText>, pub model: Interactive<ContainedText>,
pub remaining_tokens: ContainedText, pub remaining_tokens: ContainedText,

View file

@ -20,6 +20,9 @@ export default function assistant(colorScheme: ColorScheme) {
assistantSender: { assistantSender: {
...text(layer, "sans", "accent", { size: "sm", weight: "bold" }), ...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
}, },
systemSender: {
...text(layer, "sans", "variant", { size: "sm", weight: "bold" }),
},
sentAt: { sentAt: {
margin: { top: 2, left: 8 }, margin: { top: 2, left: 8 },
...text(layer, "sans", "default", { size: "2xs" }), ...text(layer, "sans", "default", { size: "2xs" }),