diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 6f26f00c52..40224b3229 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -34,6 +34,16 @@ enum Role { System, } +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + impl Display for Role { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index a61ecf202d..4a8319015f 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -485,14 +485,16 @@ impl Assistant { let messages = self .messages .iter() - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: message.content.read(cx).text(), - name: None, + .filter_map(|message| { + Some(tiktoken_rs::ChatCompletionRequestMessage { + role: match self.messages_metadata.get(&message.excerpt_id)?.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: message.content.read(cx).text(), + name: None, + }) }) .collect::>(); let model = self.model.clone(); @@ -529,9 +531,11 @@ impl Assistant { let messages = self .messages .iter() - .map(|message| RequestMessage { - role: message.role, - content: message.content.read(cx).text(), + .filter_map(|message| { + Some(RequestMessage { + role: self.messages_metadata.get(&message.excerpt_id)?.role, + content: message.content.read(cx).text(), + }) }) .collect(); let request = OpenAIRequest { @@ -621,6 +625,13 @@ impl Assistant { } } + fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext) { + if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) { + metadata.role.cycle(); + cx.notify(); + } + } + fn push_message( &mut self, role: Role, @@ -659,7 +670,6 @@ impl Assistant { self.messages.push(Message { excerpt_id, - role, content: content.clone(), }); self.messages_metadata.insert( @@ -681,9 +691,11 @@ impl Assistant { .messages .iter() .take(2) - .map(|message| RequestMessage { - role: message.role, - content: message.content.read(cx).text(), + .filter_map(|message| { + Some(RequestMessage { + role: self.messages_metadata.get(&message.excerpt_id)?.role, + content: message.content.read(cx).text(), + }) }) .chain(Some(RequestMessage { role: Role::User, @@ -753,27 +765,51 @@ impl AssistantEditor { { let assistant = assistant.clone(); move |_editor, params: editor::RenderExcerptHeaderParams, cx| { + enum Sender {} enum ErrorTooltip {} let theme = theme::current(cx); let style = &theme.assistant; - if let Some(metadata) = assistant.read(cx).messages_metadata.get(¶ms.id) + let excerpt_id = params.id; + if let Some(metadata) = assistant + .read(cx) + .messages_metadata + .get(&excerpt_id) + .cloned() { - let sender = match metadata.role { - Role::User => Label::new("You", style.user_sender.text.clone()) - .contained() - .with_style(style.user_sender.container), - Role::Assistant => { - Label::new("Assistant", style.assistant_sender.text.clone()) - .contained() - .with_style(style.assistant_sender.container) + let sender = MouseEventHandler::::new( + params.id.into(), + cx, + |state, _| match metadata.role { + Role::User => { + let style = style.user_sender.style_for(state, false); + Label::new("You", style.text.clone()) + .contained() + .with_style(style.container) + } + Role::Assistant => { + let style = style.assistant_sender.style_for(state, false); + Label::new("Assistant", style.text.clone()) + .contained() + .with_style(style.container) + } + Role::System => { + let style = style.system_sender.style_for(state, false); + Label::new("System", style.text.clone()) + .contained() + .with_style(style.container) + } + }, + ) + .with_cursor_style(CursorStyle::PointingHand) + .on_down(MouseButton::Left, { + let assistant = assistant.clone(); + move |_, _, cx| { + assistant.update(cx, |assistant, cx| { + assistant.cycle_message_role(excerpt_id, cx) + }) } - Role::System => { - Label::new("System", style.assistant_sender.text.clone()) - .contained() - .with_style(style.assistant_sender.container) - } - }; + }); Flex::row() .with_child(sender.aligned()) @@ -786,7 +822,7 @@ impl AssistantEditor { .with_style(style.sent_at.container) .aligned(), ) - .with_children(metadata.error.clone().map(|error| { + .with_children(metadata.error.map(|error| { Svg::new("icons/circle_x_mark_12.svg") .with_color(style.error_icon.color) .constrained() @@ -833,21 +869,22 @@ impl AssistantEditor { self.assistant.update(cx, |assistant, cx| { let editor = self.editor.read(cx); let newest_selection = editor.selections.newest_anchor(); - let role = if newest_selection.head() == Anchor::min() { - assistant.messages.first().map(|message| message.role) + let excerpt_id = if newest_selection.head() == Anchor::min() { + assistant.messages.first().map(|message| message.excerpt_id) } else if newest_selection.head() == Anchor::max() { - assistant.messages.last().map(|message| message.role) + assistant.messages.last().map(|message| message.excerpt_id) } else { - assistant - .messages_metadata - .get(&newest_selection.head().excerpt_id()) - .map(|message| message.role) + Some(newest_selection.head().excerpt_id()) }; - if role.map_or(false, |role| role == Role::Assistant) { - assistant.push_message(Role::User, cx); - } else { - assistant.assist(cx); + if let Some(excerpt_id) = excerpt_id { + if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) { + if metadata.role == Role::User { + assistant.assist(cx); + } else { + assistant.push_message(Role::User, cx); + } + } } }); } @@ -967,12 +1004,17 @@ impl AssistantEditor { let range = cmp::max(message_range.start, selection.range().start) ..cmp::min(message_range.end, selection.range().end); if !range.is_empty() { - spanned_messages += 1; - write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); - for chunk in assistant.buffer.read(cx).snapshot(cx).text_for_range(range) { - copied_text.push_str(&chunk); + if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id) + { + spanned_messages += 1; + write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap(); + for chunk in + assistant.buffer.read(cx).snapshot(cx).text_for_range(range) + { + copied_text.push_str(&chunk); + } + copied_text.push('\n'); } - copied_text.push('\n'); } } @@ -1090,11 +1132,10 @@ impl Item for AssistantEditor { #[derive(Debug)] struct Message { excerpt_id: ExcerptId, - role: Role, content: ModelHandle, } -#[derive(Debug)] +#[derive(Clone, Debug)] struct MessageMetadata { role: Role, sent_at: DateTime, diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index 132e37ad1c..f7df63ca09 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -974,8 +974,9 @@ pub struct AssistantStyle { pub container: ContainerStyle, pub header: ContainerStyle, pub sent_at: ContainedText, - pub user_sender: ContainedText, - pub assistant_sender: ContainedText, + pub user_sender: Interactive, + pub assistant_sender: Interactive, + pub system_sender: Interactive, pub model_info_container: ContainerStyle, pub model: Interactive, pub remaining_tokens: ContainedText, diff --git a/styles/src/styleTree/assistant.ts b/styles/src/styleTree/assistant.ts index 4314741fb0..5e33967b50 100644 --- a/styles/src/styleTree/assistant.ts +++ b/styles/src/styleTree/assistant.ts @@ -20,6 +20,9 @@ export default function assistant(colorScheme: ColorScheme) { assistantSender: { ...text(layer, "sans", "accent", { size: "sm", weight: "bold" }), }, + systemSender: { + ...text(layer, "sans", "variant", { size: "sm", weight: "bold" }), + }, sentAt: { margin: { top: 2, left: 8 }, ...text(layer, "sans", "default", { size: "2xs" }),