diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index a433c10267..c1df6c76d3 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -6,7 +6,7 @@ use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, StopReason, }; -use util::ResultExt as _; +use util::{post_inc, ResultExt as _}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -19,17 +19,24 @@ pub struct Message { pub text: String, } +struct PendingCompletion { + id: usize, + _task: Task<()>, +} + /// A thread of conversation with the LLM. pub struct Thread { messages: Vec, - pending_completion_tasks: Vec>, + completion_count: usize, + pending_completions: Vec, } impl Thread { pub fn new(_cx: &mut ModelContext) -> Self { Self { messages: Vec::new(), - pending_completion_tasks: Vec::new(), + completion_count: 0, + pending_completions: Vec::new(), } } @@ -79,7 +86,9 @@ impl Thread { model: Arc, cx: &mut ModelContext, ) { - let task = cx.spawn(|this, mut cx| async move { + let pending_completion_id = post_inc(&mut self.completion_count); + + let task = cx.spawn(|thread, mut cx| async move { let stream = model.stream_completion(request, &cx); let stream_completion = async { let mut events = stream.await?; @@ -88,7 +97,7 @@ impl Thread { while let Some(event) = events.next().await { let event = event?; - this.update(&mut cx, |thread, cx| { + thread.update(&mut cx, |thread, cx| { match event { LanguageModelCompletionEvent::StartMessage { .. } => { thread.messages.push(Message { @@ -116,6 +125,12 @@ impl Thread { smol::future::yield_now().await; } + thread.update(&mut cx, |thread, _cx| { + thread + .pending_completions + .retain(|completion| completion.id != pending_completion_id); + })?; + anyhow::Ok(stop_reason) }; @@ -123,7 +138,10 @@ impl Thread { let _ = result.log_err(); }); - self.pending_completion_tasks.push(task); + self.pending_completions.push(PendingCompletion { + id: pending_completion_id, + _task: task, + }); } }