diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 88a3f73176..abbb2f20db 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -1,7 +1,7 @@ use anyhow::Result; use gpui::{ prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, - FocusableView, Model, Pixels, Task, View, ViewContext, WeakView, WindowContext, + FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext, }; use language_model::LanguageModelRegistry; use language_model_selector::LanguageModelSelector; @@ -28,6 +28,7 @@ pub struct AssistantPanel { pane: View, thread: Model, message_editor: View, + _subscriptions: Vec, } impl AssistantPanel { @@ -59,11 +60,13 @@ impl AssistantPanel { }); let thread = cx.new_model(Thread::new); + let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())]; Self { pane, thread: thread.clone(), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), + _subscriptions: subscriptions, } } } diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 63f8c869d4..d195682cb3 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -1,14 +1,11 @@ use editor::{Editor, EditorElement, EditorStyle}; -use futures::StreamExt; use gpui::{AppContext, Model, TextStyle, View}; use language_model::{ - LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, MessageContent, Role, StopReason, + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, }; use settings::Settings; use theme::ThemeSettings; use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding}; -use util::ResultExt; use crate::thread::{self, Thread}; use crate::Chat; @@ -71,50 +68,8 @@ impl MessageEditor { editor.clear(cx); }); - let task = cx.spawn(|this, mut cx| async move { - let stream = model.stream_completion(request, &cx); - let stream_completion = async { - let mut events = stream.await?; - let mut stop_reason = StopReason::EndTurn; - - let mut text = String::new(); - - while let Some(event) = events.next().await { - let event = event?; - match event { - LanguageModelCompletionEvent::StartMessage { .. } => {} - LanguageModelCompletionEvent::Stop(reason) => { - stop_reason = reason; - } - LanguageModelCompletionEvent::Text(chunk) => { - text.push_str(&chunk); - } - LanguageModelCompletionEvent::ToolUse(_tool_use) => {} - } - - smol::future::yield_now().await; - } - - anyhow::Ok((stop_reason, text)) - }; - - let result = stream_completion.await; - - this.update(&mut cx, |this, cx| { - if let Some((_stop_reason, text)) = result.log_err() { - this.thread.update(cx, |thread, _cx| { - thread.messages.push(thread::Message { - role: Role::Assistant, - text, - }); - }); - } - }) - .ok(); - }); - - self.thread.update(cx, |thread, _cx| { - thread.pending_completion_tasks.push(task); + self.thread.update(cx, |thread, cx| { + thread.stream_completion(request, model, cx) }); None diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 1553eaabb6..a6c870b456 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,5 +1,11 @@ -use gpui::{ModelContext, Task}; -use language_model::Role; +use std::sync::Arc; + +use futures::StreamExt as _; +use gpui::{EventEmitter, ModelContext, Task}; +use language_model::{ + LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason, +}; +use util::ResultExt as _; /// A message in a [`Thread`]. pub struct Message { @@ -20,4 +26,64 @@ impl Thread { pending_completion_tasks: Vec::new(), } } + + pub fn stream_completion( + &mut self, + request: LanguageModelRequest, + model: Arc, + cx: &mut ModelContext, + ) { + let task = cx.spawn(|this, mut cx| async move { + let stream = model.stream_completion(request, &cx); + let stream_completion = async { + let mut events = stream.await?; + let mut stop_reason = StopReason::EndTurn; + + while let Some(event) = events.next().await { + let event = event?; + + this.update(&mut cx, |thread, cx| { + match event { + LanguageModelCompletionEvent::StartMessage { .. } => { + thread.messages.push(Message { + role: Role::Assistant, + text: String::new(), + }); + } + LanguageModelCompletionEvent::Stop(reason) => { + stop_reason = reason; + } + LanguageModelCompletionEvent::Text(chunk) => { + if let Some(last_message) = thread.messages.last_mut() { + if last_message.role == Role::Assistant { + last_message.text.push_str(&chunk); + } + } + } + LanguageModelCompletionEvent::ToolUse(_tool_use) => {} + } + + cx.emit(ThreadEvent::StreamedCompletion); + cx.notify(); + })?; + + smol::future::yield_now().await; + } + + anyhow::Ok(stop_reason) + }; + + let result = stream_completion.await; + let _ = result.log_err(); + }); + + self.pending_completion_tasks.push(task); + } } + +#[derive(Debug, Clone)] +pub enum ThreadEvent { + StreamedCompletion, +} + +impl EventEmitter for Thread {}