Avoid including pending or errored messages on assist

This commit is contained in:
Antonio Scandurra 2023-06-20 11:59:51 +02:00
parent cb55356106
commit 8673b0b75b
3 changed files with 94 additions and 56 deletions

1
Cargo.lock generated
View file

@ -114,6 +114,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"settings", "settings",
"smol",
"theme", "theme",
"tiktoken-rs", "tiktoken-rs",
"util", "util",

View file

@ -28,6 +28,7 @@ isahc.workspace = true
schemars.workspace = true schemars.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4" tiktoken-rs = "0.4"
[dev-dependencies] [dev-dependencies]

View file

@ -518,7 +518,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role: Role::User, role: Role::User,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
@ -595,6 +595,7 @@ impl Assistant {
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Vec<MessageAnchor> { ) -> Vec<MessageAnchor> {
let mut user_messages = Vec::new(); let mut user_messages = Vec::new();
let mut tasks = Vec::new();
for selected_message_id in selected_messages { for selected_message_id in selected_messages {
let selected_message_role = let selected_message_role =
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
@ -604,15 +605,22 @@ impl Assistant {
}; };
if selected_message_role == Role::Assistant { if selected_message_role == Role::Assistant {
let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { if let Some(user_message) = self.insert_message_after(
continue; selected_message_id,
}; Role::User,
MessageStatus::Done,
cx,
) {
user_messages.push(user_message); user_messages.push(user_message);
} else {
continue;
}
} else { } else {
let request = OpenAIRequest { let request = OpenAIRequest {
model: self.model.clone(), model: self.model.clone(),
messages: self messages: self
.messages(cx) .messages(cx)
.filter(|message| matches!(message.status, MessageStatus::Done))
.map(|message| message.to_open_ai_message(self.buffer.read(cx))) .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage { .chain(Some(RequestMessage {
role: Role::System, role: Role::System,
@ -628,10 +636,15 @@ impl Assistant {
let Some(api_key) = self.api_key.borrow().clone() else { continue }; let Some(api_key) = self.api_key.borrow().clone() else { continue };
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
let assistant_message = self let assistant_message = self
.insert_message_after(selected_message_id, Role::Assistant, cx) .insert_message_after(
selected_message_id,
Role::Assistant,
MessageStatus::Pending,
cx,
)
.unwrap(); .unwrap();
let task = cx.spawn_weak({ tasks.push(cx.spawn_weak({
|this, mut cx| async move { |this, mut cx| async move {
let assistant_message_id = assistant_message.id; let assistant_message_id = assistant_message.id;
let stream_completion = async { let stream_completion = async {
@ -648,16 +661,15 @@ impl Assistant {
|message| message.id == assistant_message_id, |message| message.id == assistant_message_id,
)?; )?;
this.buffer.update(cx, |buffer, cx| { this.buffer.update(cx, |buffer, cx| {
let offset = if message_ix + 1 let offset = this.message_anchors[message_ix + 1..]
== this.message_anchors.len() .iter()
{ .find(|message| message.start.is_valid(buffer))
buffer.len() .map_or(buffer.len(), |message| {
} else { message
this.message_anchors[message_ix + 1]
.start .start
.to_offset(buffer) .to_offset(buffer)
.saturating_sub(1) .saturating_sub(1)
}; });
buffer.edit([(offset..offset, text)], None, cx); buffer.edit([(offset..offset, text)], None, cx);
}); });
cx.emit(AssistantEvent::StreamedCompletion); cx.emit(AssistantEvent::StreamedCompletion);
@ -665,6 +677,7 @@ impl Assistant {
Some(()) Some(())
}); });
} }
smol::future::yield_now().await;
} }
this.upgrade(&cx) this.upgrade(&cx)
@ -682,25 +695,34 @@ impl Assistant {
let result = stream_completion.await; let result = stream_completion.await;
if let Some(this) = this.upgrade(&cx) { if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
if let Err(error) = result {
if let Some(metadata) = if let Some(metadata) =
this.messages_metadata.get_mut(&assistant_message.id) this.messages_metadata.get_mut(&assistant_message.id)
{ {
metadata.error = Some(error.to_string().trim().into()); match result {
Ok(_) => {
metadata.status = MessageStatus::Done;
}
Err(error) => {
metadata.status = MessageStatus::Error(
error.to_string().trim().into(),
);
}
}
cx.notify(); cx.notify();
} }
}
}); });
} }
} }
}); }));
}
}
if !tasks.is_empty() {
self.pending_completions.push(PendingCompletion { self.pending_completions.push(PendingCompletion {
id: post_inc(&mut self.completion_count), id: post_inc(&mut self.completion_count),
_task: task, _tasks: tasks,
}); });
} }
}
user_messages user_messages
} }
@ -723,6 +745,7 @@ impl Assistant {
&mut self, &mut self,
message_id: MessageId, message_id: MessageId,
role: Role, role: Role,
status: MessageStatus,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Option<MessageAnchor> { ) -> Option<MessageAnchor> {
if let Some(prev_message_ix) = self if let Some(prev_message_ix) = self
@ -749,7 +772,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status,
}, },
); );
cx.emit(AssistantEvent::MessagesEdited); cx.emit(AssistantEvent::MessagesEdited);
@ -808,7 +831,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
@ -850,7 +873,7 @@ impl Assistant {
MessageMetadata { MessageMetadata {
role, role,
sent_at: Local::now(), sent_at: Local::now(),
error: None, status: MessageStatus::Done,
}, },
); );
(Some(selection), Some(suffix)) (Some(selection), Some(suffix))
@ -970,7 +993,7 @@ impl Assistant {
anchor: message_anchor.start, anchor: message_anchor.start,
role: metadata.role, role: metadata.role,
sent_at: metadata.sent_at, sent_at: metadata.sent_at,
error: metadata.error.clone(), status: metadata.status.clone(),
}); });
} }
None None
@ -980,7 +1003,7 @@ impl Assistant {
struct PendingCompletion { struct PendingCompletion {
id: usize, id: usize,
_task: Task<()>, _tasks: Vec<Task<()>>,
} }
enum AssistantEditorEvent { enum AssistantEditorEvent {
@ -1239,7 +1262,9 @@ impl AssistantEditor {
.with_style(style.sent_at.container) .with_style(style.sent_at.container)
.aligned(), .aligned(),
) )
.with_children(message.error.as_ref().map(|error| { .with_children(
if let MessageStatus::Error(error) = &message.status {
Some(
Svg::new("icons/circle_x_mark_12.svg") Svg::new("icons/circle_x_mark_12.svg")
.with_color(style.error_icon.color) .with_color(style.error_icon.color)
.constrained() .constrained()
@ -1253,8 +1278,12 @@ impl AssistantEditor {
theme.tooltip.clone(), theme.tooltip.clone(),
cx, cx,
) )
.aligned() .aligned(),
})) )
} else {
None
},
)
.aligned() .aligned()
.left() .left()
.contained() .contained()
@ -1502,7 +1531,14 @@ struct MessageAnchor {
struct MessageMetadata { struct MessageMetadata {
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,
error: Option<Arc<str>>, status: MessageStatus,
}
#[derive(Clone, Debug)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -1513,7 +1549,7 @@ pub struct Message {
anchor: language::Anchor, anchor: language::Anchor,
role: Role, role: Role,
sent_at: DateTime<Local>, sent_at: DateTime<Local>,
error: Option<Arc<str>>, status: MessageStatus,
} }
impl Message { impl Message {
@ -1632,7 +1668,7 @@ mod tests {
let message_2 = assistant.update(cx, |assistant, cx| { let message_2 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_1.id, Role::Assistant, cx) .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1656,7 +1692,7 @@ mod tests {
let message_3 = assistant.update(cx, |assistant, cx| { let message_3 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_2.id, Role::User, cx) .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1670,7 +1706,7 @@ mod tests {
let message_4 = assistant.update(cx, |assistant, cx| { let message_4 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_2.id, Role::User, cx) .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1731,7 +1767,7 @@ mod tests {
// Ensure we can still insert after a merged message. // Ensure we can still insert after a merged message.
let message_5 = assistant.update(cx, |assistant, cx| { let message_5 = assistant.update(cx, |assistant, cx| {
assistant assistant
.insert_message_after(message_1.id, Role::System, cx) .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
.unwrap() .unwrap()
}); });
assert_eq!( assert_eq!(
@ -1852,14 +1888,14 @@ mod tests {
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
let message_2 = assistant let message_2 = assistant
.update(cx, |assistant, cx| { .update(cx, |assistant, cx| {
assistant.insert_message_after(message_1.id, Role::User, cx) assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
}) })
.unwrap(); .unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
let message_3 = assistant let message_3 = assistant
.update(cx, |assistant, cx| { .update(cx, |assistant, cx| {
assistant.insert_message_after(message_2.id, Role::User, cx) assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
}) })
.unwrap(); .unwrap();
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));