assistant2: Improve tracking of pending completions (#21186)
This PR improves the tracking of pending completions in `assistant2` such that we actually remove ones that have been completed. Release Notes: - N/A
This commit is contained in:
parent
2b9250843c
commit
cc5daa22bd
1 changed files with 24 additions and 6 deletions
|
@ -6,7 +6,7 @@ use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
MessageContent, Role, StopReason,
|
MessageContent, Role, StopReason,
|
||||||
};
|
};
|
||||||
use util::ResultExt as _;
|
use util::{post_inc, ResultExt as _};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum RequestKind {
|
pub enum RequestKind {
|
||||||
|
@ -19,17 +19,24 @@ pub struct Message {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct PendingCompletion {
|
||||||
|
id: usize,
|
||||||
|
_task: Task<()>,
|
||||||
|
}
|
||||||
|
|
||||||
/// A thread of conversation with the LLM.
|
/// A thread of conversation with the LLM.
|
||||||
pub struct Thread {
|
pub struct Thread {
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
pending_completion_tasks: Vec<Task<()>>,
|
completion_count: usize,
|
||||||
|
pending_completions: Vec<PendingCompletion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Thread {
|
impl Thread {
|
||||||
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
|
pub fn new(_cx: &mut ModelContext<Self>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
messages: Vec::new(),
|
messages: Vec::new(),
|
||||||
pending_completion_tasks: Vec::new(),
|
completion_count: 0,
|
||||||
|
pending_completions: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +86,9 @@ impl Thread {
|
||||||
model: Arc<dyn LanguageModel>,
|
model: Arc<dyn LanguageModel>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) {
|
) {
|
||||||
let task = cx.spawn(|this, mut cx| async move {
|
let pending_completion_id = post_inc(&mut self.completion_count);
|
||||||
|
|
||||||
|
let task = cx.spawn(|thread, mut cx| async move {
|
||||||
let stream = model.stream_completion(request, &cx);
|
let stream = model.stream_completion(request, &cx);
|
||||||
let stream_completion = async {
|
let stream_completion = async {
|
||||||
let mut events = stream.await?;
|
let mut events = stream.await?;
|
||||||
|
@ -88,7 +97,7 @@ impl Thread {
|
||||||
while let Some(event) = events.next().await {
|
while let Some(event) = events.next().await {
|
||||||
let event = event?;
|
let event = event?;
|
||||||
|
|
||||||
this.update(&mut cx, |thread, cx| {
|
thread.update(&mut cx, |thread, cx| {
|
||||||
match event {
|
match event {
|
||||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||||
thread.messages.push(Message {
|
thread.messages.push(Message {
|
||||||
|
@ -116,6 +125,12 @@ impl Thread {
|
||||||
smol::future::yield_now().await;
|
smol::future::yield_now().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
thread.update(&mut cx, |thread, _cx| {
|
||||||
|
thread
|
||||||
|
.pending_completions
|
||||||
|
.retain(|completion| completion.id != pending_completion_id);
|
||||||
|
})?;
|
||||||
|
|
||||||
anyhow::Ok(stop_reason)
|
anyhow::Ok(stop_reason)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -123,7 +138,10 @@ impl Thread {
|
||||||
let _ = result.log_err();
|
let _ = result.log_err();
|
||||||
});
|
});
|
||||||
|
|
||||||
self.pending_completion_tasks.push(task);
|
self.pending_completions.push(PendingCompletion {
|
||||||
|
id: pending_completion_id,
|
||||||
|
_task: task,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue