assistant2: Improve tracking of pending completions (#21186)

This PR improves the tracking of pending completions in `assistant2`
such that we actually remove ones that have been completed.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-25 17:07:55 -05:00 committed by GitHub
parent 2b9250843c
commit cc5daa22bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -6,7 +6,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason, MessageContent, Role, StopReason,
}; };
use util::ResultExt as _; use util::{post_inc, ResultExt as _};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum RequestKind { pub enum RequestKind {
@ -19,17 +19,24 @@ pub struct Message {
pub text: String, pub text: String,
} }
struct PendingCompletion {
id: usize,
_task: Task<()>,
}
/// A thread of conversation with the LLM. /// A thread of conversation with the LLM.
pub struct Thread { pub struct Thread {
messages: Vec<Message>, messages: Vec<Message>,
pending_completion_tasks: Vec<Task<()>>, completion_count: usize,
pending_completions: Vec<PendingCompletion>,
} }
impl Thread { impl Thread {
pub fn new(_cx: &mut ModelContext<Self>) -> Self { pub fn new(_cx: &mut ModelContext<Self>) -> Self {
Self { Self {
messages: Vec::new(), messages: Vec::new(),
pending_completion_tasks: Vec::new(), completion_count: 0,
pending_completions: Vec::new(),
} }
} }
@ -79,7 +86,9 @@ impl Thread {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) { ) {
let task = cx.spawn(|this, mut cx| async move { let pending_completion_id = post_inc(&mut self.completion_count);
let task = cx.spawn(|thread, mut cx| async move {
let stream = model.stream_completion(request, &cx); let stream = model.stream_completion(request, &cx);
let stream_completion = async { let stream_completion = async {
let mut events = stream.await?; let mut events = stream.await?;
@ -88,7 +97,7 @@ impl Thread {
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
let event = event?; let event = event?;
this.update(&mut cx, |thread, cx| { thread.update(&mut cx, |thread, cx| {
match event { match event {
LanguageModelCompletionEvent::StartMessage { .. } => { LanguageModelCompletionEvent::StartMessage { .. } => {
thread.messages.push(Message { thread.messages.push(Message {
@ -116,6 +125,12 @@ impl Thread {
smol::future::yield_now().await; smol::future::yield_now().await;
} }
thread.update(&mut cx, |thread, _cx| {
thread
.pending_completions
.retain(|completion| completion.id != pending_completion_id);
})?;
anyhow::Ok(stop_reason) anyhow::Ok(stop_reason)
}; };
@ -123,7 +138,10 @@ impl Thread {
let _ = result.log_err(); let _ = result.log_err();
}); });
self.pending_completion_tasks.push(task); self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
_task: task,
});
} }
} }