diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 4a8319015f..7f387194b3 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -459,7 +459,7 @@ impl Assistant { api_key, buffer, }; - this.push_message(Role::User, cx); + this.insert_message_after(ExcerptId::max(), Role::User, cx); this.count_remaining_tokens(cx); this } @@ -498,7 +498,7 @@ impl Assistant { }) .collect::>(); let model = self.model.clone(); - self.pending_token_count = cx.spawn(|this, mut cx| { + self.pending_token_count = cx.spawn_weak(|this, mut cx| { async move { cx.background().timer(Duration::from_millis(200)).await; let token_count = cx @@ -506,11 +506,13 @@ impl Assistant { .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) }) .await?; - this.update(&mut cx, |this, cx| { - this.max_token_count = tiktoken_rs::model::get_context_size(&this.model); - this.token_count = Some(token_count); - cx.notify() - }); + this.upgrade(&cx) + .ok_or_else(|| anyhow!("assistant was dropped"))? + .update(&mut cx, |this, cx| { + this.max_token_count = tiktoken_rs::model::get_context_size(&this.model); + this.token_count = Some(token_count); + cx.notify() + }); anyhow::Ok(()) } .log_err() @@ -547,9 +549,10 @@ impl Assistant { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { let stream = stream_completion(api_key, cx.background().clone(), request); - let (excerpt_id, content) = self.push_message(Role::Assistant, cx); - self.push_message(Role::User, cx); - let task = cx.spawn(|this, mut cx| async move { + let (excerpt_id, content) = + self.insert_message_after(ExcerptId::max(), Role::Assistant, cx); + self.insert_message_after(ExcerptId::max(), Role::User, cx); + let task = cx.spawn_weak(|this, mut cx| async move { let stream_completion = async { let mut messages = stream.await?; @@ -564,22 +567,26 @@ impl Assistant { } } - this.update(&mut cx, |this, cx| { - this.pending_completions - .retain(|completion| completion.id != this.completion_count); - this.summarize(cx); - }); + this.upgrade(&cx) + .ok_or_else(|| anyhow!("assistant was dropped"))? + .update(&mut cx, |this, cx| { + this.pending_completions + .retain(|completion| completion.id != this.completion_count); + this.summarize(cx); + }); anyhow::Ok(()) }; if let Err(error) = stream_completion.await { - this.update(&mut cx, |this, cx| { - if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); - } - }) + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) { + metadata.error = Some(error.to_string().trim().into()); + cx.notify(); + } + }); + } } }); @@ -632,8 +639,9 @@ impl Assistant { } } - fn push_message( + fn insert_message_after( &mut self, + excerpt_id: ExcerptId, role: Role, cx: &mut ModelContext, ) -> (ExcerptId, ModelHandle) { @@ -654,9 +662,10 @@ impl Assistant { buffer.set_language_registry(self.languages.clone()); buffer }); - let excerpt_id = self.buffer.update(cx, |buffer, cx| { + let new_excerpt_id = self.buffer.update(cx, |buffer, cx| { buffer - .push_excerpts( + .insert_excerpts_after( + excerpt_id, content.clone(), vec![ExcerptRange { context: 0..0, @@ -668,19 +677,27 @@ impl Assistant { .unwrap() }); - self.messages.push(Message { - excerpt_id, - content: content.clone(), - }); + let ix = self + .messages + .iter() + .position(|message| message.excerpt_id == excerpt_id) + .map_or(self.messages.len(), |ix| ix + 1); + self.messages.insert( + ix, + Message { + excerpt_id: new_excerpt_id, + content: content.clone(), + }, + ); self.messages_metadata.insert( - excerpt_id, + new_excerpt_id, MessageMetadata { role, sent_at: Local::now(), error: None, }, ); - (excerpt_id, content) + (new_excerpt_id, content) } fn summarize(&mut self, cx: &mut ModelContext) { @@ -882,7 +899,7 @@ impl AssistantEditor { if metadata.role == Role::User { assistant.assist(cx); } else { - assistant.push_message(Role::User, cx); + assistant.insert_message_after(excerpt_id, Role::User, cx); } } } @@ -1227,3 +1244,28 @@ async fn stream_completion( } } } + +#[cfg(test)] +mod tests { + use super::*; + use gpui::AppContext; + + #[gpui::test] + fn test_inserting_and_removing_messages(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + + cx.add_model(|cx| { + let mut assistant = Assistant::new(Default::default(), registry, cx); + let (excerpt_1, _) = + assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx); + let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx); + let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx); + assistant.remove_empty_messages( + HashSet::from_iter([excerpt_2, excerpt_3]), + Default::default(), + cx, + ); + assistant + }); + } +}