Cycle message roles on click
Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
ac7178068f
commit
ef7ec265c8
4 changed files with 106 additions and 51 deletions
|
@ -34,6 +34,16 @@ enum Role {
|
||||||
System,
|
System,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Role {
|
||||||
|
pub fn cycle(&mut self) {
|
||||||
|
*self = match self {
|
||||||
|
Role::User => Role::Assistant,
|
||||||
|
Role::Assistant => Role::System,
|
||||||
|
Role::System => Role::User,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Display for Role {
|
impl Display for Role {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
|
|
@ -485,14 +485,16 @@ impl Assistant {
|
||||||
let messages = self
|
let messages = self
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
.filter_map(|message| {
|
||||||
role: match message.role {
|
Some(tiktoken_rs::ChatCompletionRequestMessage {
|
||||||
Role::User => "user".into(),
|
role: match self.messages_metadata.get(&message.excerpt_id)?.role {
|
||||||
Role::Assistant => "assistant".into(),
|
Role::User => "user".into(),
|
||||||
Role::System => "system".into(),
|
Role::Assistant => "assistant".into(),
|
||||||
},
|
Role::System => "system".into(),
|
||||||
content: message.content.read(cx).text(),
|
},
|
||||||
name: None,
|
content: message.content.read(cx).text(),
|
||||||
|
name: None,
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
|
@ -529,9 +531,11 @@ impl Assistant {
|
||||||
let messages = self
|
let messages = self
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|message| RequestMessage {
|
.filter_map(|message| {
|
||||||
role: message.role,
|
Some(RequestMessage {
|
||||||
content: message.content.read(cx).text(),
|
role: self.messages_metadata.get(&message.excerpt_id)?.role,
|
||||||
|
content: message.content.read(cx).text(),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let request = OpenAIRequest {
|
let request = OpenAIRequest {
|
||||||
|
@ -621,6 +625,13 @@ impl Assistant {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
|
||||||
|
if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
|
||||||
|
metadata.role.cycle();
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn push_message(
|
fn push_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
role: Role,
|
role: Role,
|
||||||
|
@ -659,7 +670,6 @@ impl Assistant {
|
||||||
|
|
||||||
self.messages.push(Message {
|
self.messages.push(Message {
|
||||||
excerpt_id,
|
excerpt_id,
|
||||||
role,
|
|
||||||
content: content.clone(),
|
content: content.clone(),
|
||||||
});
|
});
|
||||||
self.messages_metadata.insert(
|
self.messages_metadata.insert(
|
||||||
|
@ -681,9 +691,11 @@ impl Assistant {
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.take(2)
|
.take(2)
|
||||||
.map(|message| RequestMessage {
|
.filter_map(|message| {
|
||||||
role: message.role,
|
Some(RequestMessage {
|
||||||
content: message.content.read(cx).text(),
|
role: self.messages_metadata.get(&message.excerpt_id)?.role,
|
||||||
|
content: message.content.read(cx).text(),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
.chain(Some(RequestMessage {
|
.chain(Some(RequestMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
|
@ -753,27 +765,51 @@ impl AssistantEditor {
|
||||||
{
|
{
|
||||||
let assistant = assistant.clone();
|
let assistant = assistant.clone();
|
||||||
move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
|
move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
|
||||||
|
enum Sender {}
|
||||||
enum ErrorTooltip {}
|
enum ErrorTooltip {}
|
||||||
|
|
||||||
let theme = theme::current(cx);
|
let theme = theme::current(cx);
|
||||||
let style = &theme.assistant;
|
let style = &theme.assistant;
|
||||||
if let Some(metadata) = assistant.read(cx).messages_metadata.get(¶ms.id)
|
let excerpt_id = params.id;
|
||||||
|
if let Some(metadata) = assistant
|
||||||
|
.read(cx)
|
||||||
|
.messages_metadata
|
||||||
|
.get(&excerpt_id)
|
||||||
|
.cloned()
|
||||||
{
|
{
|
||||||
let sender = match metadata.role {
|
let sender = MouseEventHandler::<Sender, _>::new(
|
||||||
Role::User => Label::new("You", style.user_sender.text.clone())
|
params.id.into(),
|
||||||
.contained()
|
cx,
|
||||||
.with_style(style.user_sender.container),
|
|state, _| match metadata.role {
|
||||||
Role::Assistant => {
|
Role::User => {
|
||||||
Label::new("Assistant", style.assistant_sender.text.clone())
|
let style = style.user_sender.style_for(state, false);
|
||||||
.contained()
|
Label::new("You", style.text.clone())
|
||||||
.with_style(style.assistant_sender.container)
|
.contained()
|
||||||
|
.with_style(style.container)
|
||||||
|
}
|
||||||
|
Role::Assistant => {
|
||||||
|
let style = style.assistant_sender.style_for(state, false);
|
||||||
|
Label::new("Assistant", style.text.clone())
|
||||||
|
.contained()
|
||||||
|
.with_style(style.container)
|
||||||
|
}
|
||||||
|
Role::System => {
|
||||||
|
let style = style.system_sender.style_for(state, false);
|
||||||
|
Label::new("System", style.text.clone())
|
||||||
|
.contained()
|
||||||
|
.with_style(style.container)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.with_cursor_style(CursorStyle::PointingHand)
|
||||||
|
.on_down(MouseButton::Left, {
|
||||||
|
let assistant = assistant.clone();
|
||||||
|
move |_, _, cx| {
|
||||||
|
assistant.update(cx, |assistant, cx| {
|
||||||
|
assistant.cycle_message_role(excerpt_id, cx)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Role::System => {
|
});
|
||||||
Label::new("System", style.assistant_sender.text.clone())
|
|
||||||
.contained()
|
|
||||||
.with_style(style.assistant_sender.container)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Flex::row()
|
Flex::row()
|
||||||
.with_child(sender.aligned())
|
.with_child(sender.aligned())
|
||||||
|
@ -786,7 +822,7 @@ impl AssistantEditor {
|
||||||
.with_style(style.sent_at.container)
|
.with_style(style.sent_at.container)
|
||||||
.aligned(),
|
.aligned(),
|
||||||
)
|
)
|
||||||
.with_children(metadata.error.clone().map(|error| {
|
.with_children(metadata.error.map(|error| {
|
||||||
Svg::new("icons/circle_x_mark_12.svg")
|
Svg::new("icons/circle_x_mark_12.svg")
|
||||||
.with_color(style.error_icon.color)
|
.with_color(style.error_icon.color)
|
||||||
.constrained()
|
.constrained()
|
||||||
|
@ -833,21 +869,22 @@ impl AssistantEditor {
|
||||||
self.assistant.update(cx, |assistant, cx| {
|
self.assistant.update(cx, |assistant, cx| {
|
||||||
let editor = self.editor.read(cx);
|
let editor = self.editor.read(cx);
|
||||||
let newest_selection = editor.selections.newest_anchor();
|
let newest_selection = editor.selections.newest_anchor();
|
||||||
let role = if newest_selection.head() == Anchor::min() {
|
let excerpt_id = if newest_selection.head() == Anchor::min() {
|
||||||
assistant.messages.first().map(|message| message.role)
|
assistant.messages.first().map(|message| message.excerpt_id)
|
||||||
} else if newest_selection.head() == Anchor::max() {
|
} else if newest_selection.head() == Anchor::max() {
|
||||||
assistant.messages.last().map(|message| message.role)
|
assistant.messages.last().map(|message| message.excerpt_id)
|
||||||
} else {
|
} else {
|
||||||
assistant
|
Some(newest_selection.head().excerpt_id())
|
||||||
.messages_metadata
|
|
||||||
.get(&newest_selection.head().excerpt_id())
|
|
||||||
.map(|message| message.role)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if role.map_or(false, |role| role == Role::Assistant) {
|
if let Some(excerpt_id) = excerpt_id {
|
||||||
assistant.push_message(Role::User, cx);
|
if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
|
||||||
} else {
|
if metadata.role == Role::User {
|
||||||
assistant.assist(cx);
|
assistant.assist(cx);
|
||||||
|
} else {
|
||||||
|
assistant.push_message(Role::User, cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -967,12 +1004,17 @@ impl AssistantEditor {
|
||||||
let range = cmp::max(message_range.start, selection.range().start)
|
let range = cmp::max(message_range.start, selection.range().start)
|
||||||
..cmp::min(message_range.end, selection.range().end);
|
..cmp::min(message_range.end, selection.range().end);
|
||||||
if !range.is_empty() {
|
if !range.is_empty() {
|
||||||
spanned_messages += 1;
|
if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id)
|
||||||
write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
|
{
|
||||||
for chunk in assistant.buffer.read(cx).snapshot(cx).text_for_range(range) {
|
spanned_messages += 1;
|
||||||
copied_text.push_str(&chunk);
|
write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
|
||||||
|
for chunk in
|
||||||
|
assistant.buffer.read(cx).snapshot(cx).text_for_range(range)
|
||||||
|
{
|
||||||
|
copied_text.push_str(&chunk);
|
||||||
|
}
|
||||||
|
copied_text.push('\n');
|
||||||
}
|
}
|
||||||
copied_text.push('\n');
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1090,11 +1132,10 @@ impl Item for AssistantEditor {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Message {
|
struct Message {
|
||||||
excerpt_id: ExcerptId,
|
excerpt_id: ExcerptId,
|
||||||
role: Role,
|
|
||||||
content: ModelHandle<Buffer>,
|
content: ModelHandle<Buffer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct MessageMetadata {
|
struct MessageMetadata {
|
||||||
role: Role,
|
role: Role,
|
||||||
sent_at: DateTime<Local>,
|
sent_at: DateTime<Local>,
|
||||||
|
|
|
@ -974,8 +974,9 @@ pub struct AssistantStyle {
|
||||||
pub container: ContainerStyle,
|
pub container: ContainerStyle,
|
||||||
pub header: ContainerStyle,
|
pub header: ContainerStyle,
|
||||||
pub sent_at: ContainedText,
|
pub sent_at: ContainedText,
|
||||||
pub user_sender: ContainedText,
|
pub user_sender: Interactive<ContainedText>,
|
||||||
pub assistant_sender: ContainedText,
|
pub assistant_sender: Interactive<ContainedText>,
|
||||||
|
pub system_sender: Interactive<ContainedText>,
|
||||||
pub model_info_container: ContainerStyle,
|
pub model_info_container: ContainerStyle,
|
||||||
pub model: Interactive<ContainedText>,
|
pub model: Interactive<ContainedText>,
|
||||||
pub remaining_tokens: ContainedText,
|
pub remaining_tokens: ContainedText,
|
||||||
|
|
|
@ -20,6 +20,9 @@ export default function assistant(colorScheme: ColorScheme) {
|
||||||
assistantSender: {
|
assistantSender: {
|
||||||
...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
|
...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
|
||||||
},
|
},
|
||||||
|
systemSender: {
|
||||||
|
...text(layer, "sans", "variant", { size: "sm", weight: "bold" }),
|
||||||
|
},
|
||||||
sentAt: {
|
sentAt: {
|
||||||
margin: { top: 2, left: 8 },
|
margin: { top: 2, left: 8 },
|
||||||
...text(layer, "sans", "default", { size: "2xs" }),
|
...text(layer, "sans", "default", { size: "2xs" }),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue