Avoid including pending or errored messages on assist
This commit is contained in:
parent
cb55356106
commit
8673b0b75b
3 changed files with 94 additions and 56 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -114,6 +114,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"settings",
|
"settings",
|
||||||
|
"smol",
|
||||||
"theme",
|
"theme",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
"util",
|
"util",
|
||||||
|
|
|
@ -28,6 +28,7 @@ isahc.workspace = true
|
||||||
schemars.workspace = true
|
schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
smol.workspace = true
|
||||||
tiktoken-rs = "0.4"
|
tiktoken-rs = "0.4"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -518,7 +518,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -595,6 +595,7 @@ impl Assistant {
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Vec<MessageAnchor> {
|
) -> Vec<MessageAnchor> {
|
||||||
let mut user_messages = Vec::new();
|
let mut user_messages = Vec::new();
|
||||||
|
let mut tasks = Vec::new();
|
||||||
for selected_message_id in selected_messages {
|
for selected_message_id in selected_messages {
|
||||||
let selected_message_role =
|
let selected_message_role =
|
||||||
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
|
if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
|
||||||
|
@ -604,15 +605,22 @@ impl Assistant {
|
||||||
};
|
};
|
||||||
|
|
||||||
if selected_message_role == Role::Assistant {
|
if selected_message_role == Role::Assistant {
|
||||||
let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else {
|
if let Some(user_message) = self.insert_message_after(
|
||||||
continue;
|
selected_message_id,
|
||||||
};
|
Role::User,
|
||||||
|
MessageStatus::Done,
|
||||||
|
cx,
|
||||||
|
) {
|
||||||
user_messages.push(user_message);
|
user_messages.push(user_message);
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let request = OpenAIRequest {
|
let request = OpenAIRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages: self
|
messages: self
|
||||||
.messages(cx)
|
.messages(cx)
|
||||||
|
.filter(|message| matches!(message.status, MessageStatus::Done))
|
||||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||||
.chain(Some(RequestMessage {
|
.chain(Some(RequestMessage {
|
||||||
role: Role::System,
|
role: Role::System,
|
||||||
|
@ -628,10 +636,15 @@ impl Assistant {
|
||||||
let Some(api_key) = self.api_key.borrow().clone() else { continue };
|
let Some(api_key) = self.api_key.borrow().clone() else { continue };
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(selected_message_id, Role::Assistant, cx)
|
.insert_message_after(
|
||||||
|
selected_message_id,
|
||||||
|
Role::Assistant,
|
||||||
|
MessageStatus::Pending,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let task = cx.spawn_weak({
|
tasks.push(cx.spawn_weak({
|
||||||
|this, mut cx| async move {
|
|this, mut cx| async move {
|
||||||
let assistant_message_id = assistant_message.id;
|
let assistant_message_id = assistant_message.id;
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
|
@ -648,16 +661,15 @@ impl Assistant {
|
||||||
|message| message.id == assistant_message_id,
|
|message| message.id == assistant_message_id,
|
||||||
)?;
|
)?;
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
let offset = if message_ix + 1
|
let offset = this.message_anchors[message_ix + 1..]
|
||||||
== this.message_anchors.len()
|
.iter()
|
||||||
{
|
.find(|message| message.start.is_valid(buffer))
|
||||||
buffer.len()
|
.map_or(buffer.len(), |message| {
|
||||||
} else {
|
message
|
||||||
this.message_anchors[message_ix + 1]
|
|
||||||
.start
|
.start
|
||||||
.to_offset(buffer)
|
.to_offset(buffer)
|
||||||
.saturating_sub(1)
|
.saturating_sub(1)
|
||||||
};
|
});
|
||||||
buffer.edit([(offset..offset, text)], None, cx);
|
buffer.edit([(offset..offset, text)], None, cx);
|
||||||
});
|
});
|
||||||
cx.emit(AssistantEvent::StreamedCompletion);
|
cx.emit(AssistantEvent::StreamedCompletion);
|
||||||
|
@ -665,6 +677,7 @@ impl Assistant {
|
||||||
Some(())
|
Some(())
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
smol::future::yield_now().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.upgrade(&cx)
|
this.upgrade(&cx)
|
||||||
|
@ -682,25 +695,34 @@ impl Assistant {
|
||||||
let result = stream_completion.await;
|
let result = stream_completion.await;
|
||||||
if let Some(this) = this.upgrade(&cx) {
|
if let Some(this) = this.upgrade(&cx) {
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
if let Err(error) = result {
|
|
||||||
if let Some(metadata) =
|
if let Some(metadata) =
|
||||||
this.messages_metadata.get_mut(&assistant_message.id)
|
this.messages_metadata.get_mut(&assistant_message.id)
|
||||||
{
|
{
|
||||||
metadata.error = Some(error.to_string().trim().into());
|
match result {
|
||||||
|
Ok(_) => {
|
||||||
|
metadata.status = MessageStatus::Done;
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
metadata.status = MessageStatus::Error(
|
||||||
|
error.to_string().trim().into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tasks.is_empty() {
|
||||||
self.pending_completions.push(PendingCompletion {
|
self.pending_completions.push(PendingCompletion {
|
||||||
id: post_inc(&mut self.completion_count),
|
id: post_inc(&mut self.completion_count),
|
||||||
_task: task,
|
_tasks: tasks,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
user_messages
|
user_messages
|
||||||
}
|
}
|
||||||
|
@ -723,6 +745,7 @@ impl Assistant {
|
||||||
&mut self,
|
&mut self,
|
||||||
message_id: MessageId,
|
message_id: MessageId,
|
||||||
role: Role,
|
role: Role,
|
||||||
|
status: MessageStatus,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Option<MessageAnchor> {
|
) -> Option<MessageAnchor> {
|
||||||
if let Some(prev_message_ix) = self
|
if let Some(prev_message_ix) = self
|
||||||
|
@ -749,7 +772,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
cx.emit(AssistantEvent::MessagesEdited);
|
cx.emit(AssistantEvent::MessagesEdited);
|
||||||
|
@ -808,7 +831,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -850,7 +873,7 @@ impl Assistant {
|
||||||
MessageMetadata {
|
MessageMetadata {
|
||||||
role,
|
role,
|
||||||
sent_at: Local::now(),
|
sent_at: Local::now(),
|
||||||
error: None,
|
status: MessageStatus::Done,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
(Some(selection), Some(suffix))
|
(Some(selection), Some(suffix))
|
||||||
|
@ -970,7 +993,7 @@ impl Assistant {
|
||||||
anchor: message_anchor.start,
|
anchor: message_anchor.start,
|
||||||
role: metadata.role,
|
role: metadata.role,
|
||||||
sent_at: metadata.sent_at,
|
sent_at: metadata.sent_at,
|
||||||
error: metadata.error.clone(),
|
status: metadata.status.clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
|
@ -980,7 +1003,7 @@ impl Assistant {
|
||||||
|
|
||||||
struct PendingCompletion {
|
struct PendingCompletion {
|
||||||
id: usize,
|
id: usize,
|
||||||
_task: Task<()>,
|
_tasks: Vec<Task<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum AssistantEditorEvent {
|
enum AssistantEditorEvent {
|
||||||
|
@ -1239,7 +1262,9 @@ impl AssistantEditor {
|
||||||
.with_style(style.sent_at.container)
|
.with_style(style.sent_at.container)
|
||||||
.aligned(),
|
.aligned(),
|
||||||
)
|
)
|
||||||
.with_children(message.error.as_ref().map(|error| {
|
.with_children(
|
||||||
|
if let MessageStatus::Error(error) = &message.status {
|
||||||
|
Some(
|
||||||
Svg::new("icons/circle_x_mark_12.svg")
|
Svg::new("icons/circle_x_mark_12.svg")
|
||||||
.with_color(style.error_icon.color)
|
.with_color(style.error_icon.color)
|
||||||
.constrained()
|
.constrained()
|
||||||
|
@ -1253,8 +1278,12 @@ impl AssistantEditor {
|
||||||
theme.tooltip.clone(),
|
theme.tooltip.clone(),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
.aligned()
|
.aligned(),
|
||||||
}))
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
.aligned()
|
.aligned()
|
||||||
.left()
|
.left()
|
||||||
.contained()
|
.contained()
|
||||||
|
@ -1502,7 +1531,14 @@ struct MessageAnchor {
|
||||||
struct MessageMetadata {
|
struct MessageMetadata {
|
||||||
role: Role,
|
role: Role,
|
||||||
sent_at: DateTime<Local>,
|
sent_at: DateTime<Local>,
|
||||||
error: Option<Arc<str>>,
|
status: MessageStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum MessageStatus {
|
||||||
|
Pending,
|
||||||
|
Done,
|
||||||
|
Error(Arc<str>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -1513,7 +1549,7 @@ pub struct Message {
|
||||||
anchor: language::Anchor,
|
anchor: language::Anchor,
|
||||||
role: Role,
|
role: Role,
|
||||||
sent_at: DateTime<Local>,
|
sent_at: DateTime<Local>,
|
||||||
error: Option<Arc<str>>,
|
status: MessageStatus,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
@ -1632,7 +1668,7 @@ mod tests {
|
||||||
|
|
||||||
let message_2 = assistant.update(cx, |assistant, cx| {
|
let message_2 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_1.id, Role::Assistant, cx)
|
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1656,7 +1692,7 @@ mod tests {
|
||||||
|
|
||||||
let message_3 = assistant.update(cx, |assistant, cx| {
|
let message_3 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_2.id, Role::User, cx)
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1670,7 +1706,7 @@ mod tests {
|
||||||
|
|
||||||
let message_4 = assistant.update(cx, |assistant, cx| {
|
let message_4 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_2.id, Role::User, cx)
|
.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1731,7 +1767,7 @@ mod tests {
|
||||||
// Ensure we can still insert after a merged message.
|
// Ensure we can still insert after a merged message.
|
||||||
let message_5 = assistant.update(cx, |assistant, cx| {
|
let message_5 = assistant.update(cx, |assistant, cx| {
|
||||||
assistant
|
assistant
|
||||||
.insert_message_after(message_1.id, Role::System, cx)
|
.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1852,14 +1888,14 @@ mod tests {
|
||||||
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
|
||||||
let message_2 = assistant
|
let message_2 = assistant
|
||||||
.update(cx, |assistant, cx| {
|
.update(cx, |assistant, cx| {
|
||||||
assistant.insert_message_after(message_1.id, Role::User, cx)
|
assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
|
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
|
||||||
|
|
||||||
let message_3 = assistant
|
let message_3 = assistant
|
||||||
.update(cx, |assistant, cx| {
|
.update(cx, |assistant, cx| {
|
||||||
assistant.insert_message_after(message_2.id, Role::User, cx)
|
assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
|
buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue