Cycle message roles on click

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-06-06 18:45:08 +02:00
parent ac7178068f
commit ef7ec265c8
4 changed files with 106 additions and 51 deletions

View file

@ -34,6 +34,16 @@ enum Role {
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View file

@ -485,14 +485,16 @@ impl Assistant {
let messages = self
.messages
.iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: message.content.read(cx).text(),
name: None,
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
role: match self.messages_metadata.get(&message.excerpt_id)?.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: message.content.read(cx).text(),
name: None,
})
})
.collect::<Vec<_>>();
let model = self.model.clone();
@ -529,9 +531,11 @@ impl Assistant {
let messages = self
.messages
.iter()
.map(|message| RequestMessage {
role: message.role,
content: message.content.read(cx).text(),
.filter_map(|message| {
Some(RequestMessage {
role: self.messages_metadata.get(&message.excerpt_id)?.role,
content: message.content.read(cx).text(),
})
})
.collect();
let request = OpenAIRequest {
@ -621,6 +625,13 @@ impl Assistant {
}
}
fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
metadata.role.cycle();
cx.notify();
}
}
fn push_message(
&mut self,
role: Role,
@ -659,7 +670,6 @@ impl Assistant {
self.messages.push(Message {
excerpt_id,
role,
content: content.clone(),
});
self.messages_metadata.insert(
@ -681,9 +691,11 @@ impl Assistant {
.messages
.iter()
.take(2)
.map(|message| RequestMessage {
role: message.role,
content: message.content.read(cx).text(),
.filter_map(|message| {
Some(RequestMessage {
role: self.messages_metadata.get(&message.excerpt_id)?.role,
content: message.content.read(cx).text(),
})
})
.chain(Some(RequestMessage {
role: Role::User,
@ -753,27 +765,51 @@ impl AssistantEditor {
{
let assistant = assistant.clone();
move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
enum Sender {}
enum ErrorTooltip {}
let theme = theme::current(cx);
let style = &theme.assistant;
if let Some(metadata) = assistant.read(cx).messages_metadata.get(&params.id)
let excerpt_id = params.id;
if let Some(metadata) = assistant
.read(cx)
.messages_metadata
.get(&excerpt_id)
.cloned()
{
let sender = match metadata.role {
Role::User => Label::new("You", style.user_sender.text.clone())
.contained()
.with_style(style.user_sender.container),
Role::Assistant => {
Label::new("Assistant", style.assistant_sender.text.clone())
.contained()
.with_style(style.assistant_sender.container)
let sender = MouseEventHandler::<Sender, _>::new(
params.id.into(),
cx,
|state, _| match metadata.role {
Role::User => {
let style = style.user_sender.style_for(state, false);
Label::new("You", style.text.clone())
.contained()
.with_style(style.container)
}
Role::Assistant => {
let style = style.assistant_sender.style_for(state, false);
Label::new("Assistant", style.text.clone())
.contained()
.with_style(style.container)
}
Role::System => {
let style = style.system_sender.style_for(state, false);
Label::new("System", style.text.clone())
.contained()
.with_style(style.container)
}
},
)
.with_cursor_style(CursorStyle::PointingHand)
.on_down(MouseButton::Left, {
let assistant = assistant.clone();
move |_, _, cx| {
assistant.update(cx, |assistant, cx| {
assistant.cycle_message_role(excerpt_id, cx)
})
}
Role::System => {
Label::new("System", style.assistant_sender.text.clone())
.contained()
.with_style(style.assistant_sender.container)
}
};
});
Flex::row()
.with_child(sender.aligned())
@ -786,7 +822,7 @@ impl AssistantEditor {
.with_style(style.sent_at.container)
.aligned(),
)
.with_children(metadata.error.clone().map(|error| {
.with_children(metadata.error.map(|error| {
Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color)
.constrained()
@ -833,21 +869,22 @@ impl AssistantEditor {
self.assistant.update(cx, |assistant, cx| {
let editor = self.editor.read(cx);
let newest_selection = editor.selections.newest_anchor();
let role = if newest_selection.head() == Anchor::min() {
assistant.messages.first().map(|message| message.role)
let excerpt_id = if newest_selection.head() == Anchor::min() {
assistant.messages.first().map(|message| message.excerpt_id)
} else if newest_selection.head() == Anchor::max() {
assistant.messages.last().map(|message| message.role)
assistant.messages.last().map(|message| message.excerpt_id)
} else {
assistant
.messages_metadata
.get(&newest_selection.head().excerpt_id())
.map(|message| message.role)
Some(newest_selection.head().excerpt_id())
};
if role.map_or(false, |role| role == Role::Assistant) {
assistant.push_message(Role::User, cx);
} else {
assistant.assist(cx);
if let Some(excerpt_id) = excerpt_id {
if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
if metadata.role == Role::User {
assistant.assist(cx);
} else {
assistant.push_message(Role::User, cx);
}
}
}
});
}
@ -967,12 +1004,17 @@ impl AssistantEditor {
let range = cmp::max(message_range.start, selection.range().start)
..cmp::min(message_range.end, selection.range().end);
if !range.is_empty() {
spanned_messages += 1;
write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
for chunk in assistant.buffer.read(cx).snapshot(cx).text_for_range(range) {
copied_text.push_str(&chunk);
if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id)
{
spanned_messages += 1;
write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
for chunk in
assistant.buffer.read(cx).snapshot(cx).text_for_range(range)
{
copied_text.push_str(&chunk);
}
copied_text.push('\n');
}
copied_text.push('\n');
}
}
@ -1090,11 +1132,10 @@ impl Item for AssistantEditor {
#[derive(Debug)]
struct Message {
excerpt_id: ExcerptId,
role: Role,
content: ModelHandle<Buffer>,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,

View file

@ -974,8 +974,9 @@ pub struct AssistantStyle {
pub container: ContainerStyle,
pub header: ContainerStyle,
pub sent_at: ContainedText,
pub user_sender: ContainedText,
pub assistant_sender: ContainedText,
pub user_sender: Interactive<ContainedText>,
pub assistant_sender: Interactive<ContainedText>,
pub system_sender: Interactive<ContainedText>,
pub model_info_container: ContainerStyle,
pub model: Interactive<ContainedText>,
pub remaining_tokens: ContainedText,

View file

@ -20,6 +20,9 @@ export default function assistant(colorScheme: ColorScheme) {
assistantSender: {
...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
},
systemSender: {
...text(layer, "sans", "variant", { size: "sm", weight: "bold" }),
},
sentAt: {
margin: { top: 2, left: 8 },
...text(layer, "sans", "default", { size: "2xs" }),