assistant2: Add support for editing the last message sent by the user (#26037)
https://github.com/user-attachments/assets/df46632b-dfeb-4991-ab2e-86829b72be9b Closes #ISSUE Release Notes: - N/A
This commit is contained in:
parent
6685d85f49
commit
f4899d92a4
4 changed files with 304 additions and 18 deletions
|
@ -626,6 +626,15 @@
|
||||||
"enter": "assistant2::Chat"
|
"enter": "assistant2::Chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "EditMessageEditor > Editor",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"escape": "menu::Cancel",
|
||||||
|
"enter": "menu::Confirm",
|
||||||
|
"alt-enter": "editor::Newline"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"context": "ContextStrip",
|
"context": "ContextStrip",
|
||||||
"bindings": {
|
"bindings": {
|
||||||
|
|
|
@ -271,6 +271,15 @@
|
||||||
"enter": "assistant2::Chat"
|
"enter": "assistant2::Chat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"context": "EditMessageEditor > Editor",
|
||||||
|
"use_key_equivalents": true,
|
||||||
|
"bindings": {
|
||||||
|
"escape": "menu::Cancel",
|
||||||
|
"enter": "menu::Confirm",
|
||||||
|
"alt-enter": "editor::Newline"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"context": "ContextStrip",
|
"context": "ContextStrip",
|
||||||
"use_key_equivalents": true,
|
"use_key_equivalents": true,
|
||||||
|
|
|
@ -2,17 +2,18 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use assistant_tool::ToolWorkingSet;
|
use assistant_tool::ToolWorkingSet;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
use editor::{Editor, MultiBuffer};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
|
list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity,
|
||||||
ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
|
Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
|
||||||
UnderlineStyle, WeakEntity,
|
TextStyleRefinement, UnderlineStyle, WeakEntity,
|
||||||
};
|
};
|
||||||
use language::LanguageRegistry;
|
use language::{Buffer, LanguageRegistry};
|
||||||
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
|
||||||
use markdown::{Markdown, MarkdownStyle};
|
use markdown::{Markdown, MarkdownStyle};
|
||||||
use settings::Settings as _;
|
use settings::Settings as _;
|
||||||
use theme::ThemeSettings;
|
use theme::ThemeSettings;
|
||||||
use ui::{prelude::*, Disclosure};
|
use ui::{prelude::*, Disclosure, KeyBinding};
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
|
||||||
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
|
||||||
|
@ -29,11 +30,16 @@ pub struct ActiveThread {
|
||||||
messages: Vec<MessageId>,
|
messages: Vec<MessageId>,
|
||||||
list_state: ListState,
|
list_state: ListState,
|
||||||
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
|
||||||
|
editing_message: Option<(MessageId, EditMessageState)>,
|
||||||
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
|
||||||
last_error: Option<ThreadError>,
|
last_error: Option<ThreadError>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct EditMessageState {
|
||||||
|
editor: Entity<Editor>,
|
||||||
|
}
|
||||||
|
|
||||||
impl ActiveThread {
|
impl ActiveThread {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
thread: Entity<Thread>,
|
thread: Entity<Thread>,
|
||||||
|
@ -60,11 +66,12 @@ impl ActiveThread {
|
||||||
expanded_tool_uses: HashMap::default(),
|
expanded_tool_uses: HashMap::default(),
|
||||||
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
|
||||||
let this = cx.entity().downgrade();
|
let this = cx.entity().downgrade();
|
||||||
move |ix, _: &mut Window, cx: &mut App| {
|
move |ix, window: &mut Window, cx: &mut App| {
|
||||||
this.update(cx, |this, cx| this.render_message(ix, cx))
|
this.update(cx, |this, cx| this.render_message(ix, window, cx))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
|
editing_message: None,
|
||||||
last_error: None,
|
last_error: None,
|
||||||
_subscriptions: subscriptions,
|
_subscriptions: subscriptions,
|
||||||
};
|
};
|
||||||
|
@ -117,6 +124,44 @@ impl ActiveThread {
|
||||||
self.messages.push(*id);
|
self.messages.push(*id);
|
||||||
self.list_state.splice(old_len..old_len, 1);
|
self.list_state.splice(old_len..old_len, 1);
|
||||||
|
|
||||||
|
let markdown = self.render_markdown(text.into(), window, cx);
|
||||||
|
self.rendered_messages_by_id.insert(*id, markdown);
|
||||||
|
self.list_state.scroll_to(ListOffset {
|
||||||
|
item_ix: old_len,
|
||||||
|
offset_in_item: Pixels(0.0),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn edited_message(
|
||||||
|
&mut self,
|
||||||
|
id: &MessageId,
|
||||||
|
text: String,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
self.list_state.splice(index..index + 1, 1);
|
||||||
|
let markdown = self.render_markdown(text.into(), window, cx);
|
||||||
|
self.rendered_messages_by_id.insert(*id, markdown);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deleted_message(&mut self, id: &MessageId) {
|
||||||
|
let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
self.messages.remove(index);
|
||||||
|
self.list_state.splice(index..index + 1, 0);
|
||||||
|
self.rendered_messages_by_id.remove(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_markdown(
|
||||||
|
&self,
|
||||||
|
text: SharedString,
|
||||||
|
window: &Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Entity<Markdown> {
|
||||||
let theme_settings = ThemeSettings::get_global(cx);
|
let theme_settings = ThemeSettings::get_global(cx);
|
||||||
let colors = cx.theme().colors();
|
let colors = cx.theme().colors();
|
||||||
let ui_font_size = TextSize::Default.rems(cx);
|
let ui_font_size = TextSize::Default.rems(cx);
|
||||||
|
@ -182,20 +227,15 @@ impl ActiveThread {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let markdown = cx.new(|cx| {
|
cx.new(|cx| {
|
||||||
Markdown::new(
|
Markdown::new(
|
||||||
text.into(),
|
text,
|
||||||
markdown_style,
|
markdown_style,
|
||||||
Some(self.language_registry.clone()),
|
Some(self.language_registry.clone()),
|
||||||
None,
|
None,
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
})
|
||||||
self.rendered_messages_by_id.insert(*id, markdown);
|
|
||||||
self.list_state.scroll_to(ListOffset {
|
|
||||||
item_ix: old_len,
|
|
||||||
offset_in_item: Pixels(0.0),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_thread_event(
|
fn handle_thread_event(
|
||||||
|
@ -241,6 +281,35 @@ impl ActiveThread {
|
||||||
|
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
ThreadEvent::MessageEdited(message_id) => {
|
||||||
|
if let Some(message_text) = self
|
||||||
|
.thread
|
||||||
|
.read(cx)
|
||||||
|
.message(*message_id)
|
||||||
|
.map(|message| message.text.clone())
|
||||||
|
{
|
||||||
|
self.edited_message(message_id, message_text, window, cx);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.thread_store
|
||||||
|
.update(cx, |thread_store, cx| {
|
||||||
|
thread_store.save_thread(&self.thread, cx)
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
ThreadEvent::MessageDeleted(message_id) => {
|
||||||
|
self.deleted_message(message_id);
|
||||||
|
|
||||||
|
self.thread_store
|
||||||
|
.update(cx, |thread_store, cx| {
|
||||||
|
thread_store.save_thread(&self.thread, cx)
|
||||||
|
})
|
||||||
|
.detach_and_log_err(cx);
|
||||||
|
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
ThreadEvent::UsePendingTools => {
|
ThreadEvent::UsePendingTools => {
|
||||||
let pending_tool_uses = self
|
let pending_tool_uses = self
|
||||||
.thread
|
.thread
|
||||||
|
@ -289,7 +358,101 @@ impl ActiveThread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
|
fn start_editing_message(
|
||||||
|
&mut self,
|
||||||
|
message_id: MessageId,
|
||||||
|
message_text: String,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let buffer = cx.new(|cx| {
|
||||||
|
MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
|
||||||
|
});
|
||||||
|
let editor = cx.new(|cx| {
|
||||||
|
let mut editor = Editor::new(
|
||||||
|
editor::EditorMode::AutoHeight { max_lines: 8 },
|
||||||
|
buffer,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
editor.focus_handle(cx).focus(window);
|
||||||
|
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
|
||||||
|
editor
|
||||||
|
});
|
||||||
|
self.editing_message = Some((
|
||||||
|
message_id,
|
||||||
|
EditMessageState {
|
||||||
|
editor: editor.clone(),
|
||||||
|
},
|
||||||
|
));
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
|
||||||
|
self.editing_message.take();
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn confirm_editing_message(
|
||||||
|
&mut self,
|
||||||
|
_: &menu::Confirm,
|
||||||
|
_: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let Some((message_id, state)) = self.editing_message.take() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let edited_text = state.editor.read(cx).text(cx);
|
||||||
|
self.thread.update(cx, |thread, cx| {
|
||||||
|
thread.edit_message(message_id, Role::User, edited_text, cx);
|
||||||
|
for message_id in self.messages_after(message_id) {
|
||||||
|
thread.delete_message(*message_id, cx);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let provider = LanguageModelRegistry::read_global(cx).active_provider();
|
||||||
|
if provider
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |provider| provider.must_accept_terms(cx))
|
||||||
|
{
|
||||||
|
cx.notify();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||||
|
let Some(model) = model_registry.active_model() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
self.thread.update(cx, |thread, cx| {
|
||||||
|
thread.send_to_model(model, RequestKind::Chat, false, cx)
|
||||||
|
});
|
||||||
|
cx.notify();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
|
||||||
|
self.messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.find(|message_id| {
|
||||||
|
self.thread
|
||||||
|
.read(cx)
|
||||||
|
.message(**message_id)
|
||||||
|
.map_or(false, |message| message.role == Role::User)
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
|
||||||
|
self.messages
|
||||||
|
.iter()
|
||||||
|
.position(|id| *id == message_id)
|
||||||
|
.map(|index| &self.messages[index + 1..])
|
||||||
|
.unwrap_or(&[])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
|
||||||
let message_id = self.messages[ix];
|
let message_id = self.messages[ix];
|
||||||
let Some(message) = self.thread.read(cx).message(message_id) else {
|
let Some(message) = self.thread.read(cx).message(message_id) else {
|
||||||
return Empty.into_any();
|
return Empty.into_any();
|
||||||
|
@ -308,8 +471,28 @@ impl ActiveThread {
|
||||||
return Empty.into_any();
|
return Empty.into_any();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let allow_editing_message =
|
||||||
|
message.role == Role::User && self.last_user_message(cx) == Some(message_id);
|
||||||
|
|
||||||
|
let edit_message_editor = self
|
||||||
|
.editing_message
|
||||||
|
.as_ref()
|
||||||
|
.filter(|(id, _)| *id == message_id)
|
||||||
|
.map(|(_, state)| state.editor.clone());
|
||||||
|
|
||||||
let message_content = v_flex()
|
let message_content = v_flex()
|
||||||
.child(div().p_2p5().text_ui(cx).child(markdown.clone()))
|
.child(
|
||||||
|
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||||
|
div()
|
||||||
|
.key_context("EditMessageEditor")
|
||||||
|
.on_action(cx.listener(Self::cancel_editing_message))
|
||||||
|
.on_action(cx.listener(Self::confirm_editing_message))
|
||||||
|
.p_2p5()
|
||||||
|
.child(edit_message_editor)
|
||||||
|
} else {
|
||||||
|
div().p_2p5().text_ui(cx).child(markdown.clone())
|
||||||
|
},
|
||||||
|
)
|
||||||
.when_some(context, |parent, context| {
|
.when_some(context, |parent, context| {
|
||||||
if !context.is_empty() {
|
if !context.is_empty() {
|
||||||
parent.child(
|
parent.child(
|
||||||
|
@ -358,6 +541,55 @@ impl ActiveThread {
|
||||||
.size(LabelSize::Small)
|
.size(LabelSize::Small)
|
||||||
.color(Color::Muted),
|
.color(Color::Muted),
|
||||||
),
|
),
|
||||||
|
)
|
||||||
|
.when_some(
|
||||||
|
edit_message_editor.clone(),
|
||||||
|
|this, edit_message_editor| {
|
||||||
|
let focus_handle = edit_message_editor.focus_handle(cx);
|
||||||
|
this.child(
|
||||||
|
h_flex()
|
||||||
|
.gap_1()
|
||||||
|
.child(
|
||||||
|
Button::new("cancel-edit-message", "Cancel")
|
||||||
|
.key_binding(KeyBinding::for_action_in(
|
||||||
|
&menu::Cancel,
|
||||||
|
&focus_handle,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.child(
|
||||||
|
Button::new(
|
||||||
|
"confirm-edit-message",
|
||||||
|
"Regenerate",
|
||||||
|
)
|
||||||
|
.key_binding(KeyBinding::for_action_in(
|
||||||
|
&menu::Confirm,
|
||||||
|
&focus_handle,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.when(
|
||||||
|
edit_message_editor.is_none() && allow_editing_message,
|
||||||
|
|this| {
|
||||||
|
this.child(Button::new("edit-message", "Edit").on_click(
|
||||||
|
cx.listener({
|
||||||
|
let message_text = message.text.clone();
|
||||||
|
move |this, _, window, cx| {
|
||||||
|
this.start_editing_message(
|
||||||
|
message_id,
|
||||||
|
message_text.clone(),
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.child(message_content),
|
.child(message_content),
|
||||||
|
|
|
@ -99,7 +99,13 @@ impl Thread {
|
||||||
tools: Arc<ToolWorkingSet>,
|
tools: Arc<ToolWorkingSet>,
|
||||||
_cx: &mut Context<Self>,
|
_cx: &mut Context<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let next_message_id = MessageId(saved.messages.len());
|
let next_message_id = MessageId(
|
||||||
|
saved
|
||||||
|
.messages
|
||||||
|
.last()
|
||||||
|
.map(|message| message.id.0 + 1)
|
||||||
|
.unwrap_or(0),
|
||||||
|
);
|
||||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
@ -229,6 +235,34 @@ impl Thread {
|
||||||
id
|
id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn edit_message(
|
||||||
|
&mut self,
|
||||||
|
id: MessageId,
|
||||||
|
new_role: Role,
|
||||||
|
new_text: String,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> bool {
|
||||||
|
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
message.role = new_role;
|
||||||
|
message.text = new_text;
|
||||||
|
self.touch_updated_at();
|
||||||
|
cx.emit(ThreadEvent::MessageEdited(id));
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
|
||||||
|
let Some(index) = self.messages.iter().position(|message| message.id == id) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
self.messages.remove(index);
|
||||||
|
self.context_by_message.remove(&id);
|
||||||
|
self.touch_updated_at();
|
||||||
|
cx.emit(ThreadEvent::MessageDeleted(id));
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the representation of this [`Thread`] in a textual form.
|
/// Returns the representation of this [`Thread`] in a textual form.
|
||||||
///
|
///
|
||||||
/// This is the representation we use when attaching a thread as context to another thread.
|
/// This is the representation we use when attaching a thread as context to another thread.
|
||||||
|
@ -567,6 +601,8 @@ pub enum ThreadEvent {
|
||||||
StreamedCompletion,
|
StreamedCompletion,
|
||||||
StreamedAssistantText(MessageId, String),
|
StreamedAssistantText(MessageId, String),
|
||||||
MessageAdded(MessageId),
|
MessageAdded(MessageId),
|
||||||
|
MessageEdited(MessageId),
|
||||||
|
MessageDeleted(MessageId),
|
||||||
SummaryChanged,
|
SummaryChanged,
|
||||||
UsePendingTools,
|
UsePendingTools,
|
||||||
ToolFinished {
|
ToolFinished {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue