assistant2: Stream in completion text (#21182)

This PR streams the completion text into the message list as it arrives, rather than buffering it until the completion finishes.
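
In rough outline: the streaming loop moves out of MessageEditor and into the Thread model, which pushes an empty assistant message when the stream starts and appends each text chunk to it as it arrives. A simplified sketch of the per-event handling added in thread.rs below (error handling, stop reasons, and tool use omitted):

    // Simplified from Thread::stream_completion in the diff below.
    while let Some(event) = events.next().await {
        let event = event?;
        this.update(&mut cx, |thread, cx| {
            match event {
                // A new assistant message is inserted empty up front...
                LanguageModelCompletionEvent::StartMessage { .. } => {
                    thread.messages.push(Message {
                        role: Role::Assistant,
                        text: String::new(),
                    });
                }
                // ...and grows chunk by chunk, instead of landing all at once.
                LanguageModelCompletionEvent::Text(chunk) => {
                    if let Some(last) = thread.messages.last_mut() {
                        last.text.push_str(&chunk);
                    }
                }
                _ => {}
            }
            // Notifying per event is what makes the UI repaint as text streams in.
            cx.notify();
        })?;
    }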

Release Notes:

- N/A
Marshall Bowers 2024-11-25 16:13:27 -05:00 committed by GitHub
parent 91a565f5fa
commit 9ee1aba80a
3 changed files with 75 additions and 51 deletions


@@ -1,7 +1,7 @@
 use anyhow::Result;
 use gpui::{
     prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, Model, Pixels, Task, View, ViewContext, WeakView, WindowContext,
+    FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
 };
 use language_model::LanguageModelRegistry;
 use language_model_selector::LanguageModelSelector;
@@ -28,6 +28,7 @@ pub struct AssistantPanel {
     pane: View<Pane>,
     thread: Model<Thread>,
     message_editor: View<MessageEditor>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl AssistantPanel {
@@ -59,11 +60,13 @@ impl AssistantPanel {
         });
 
         let thread = cx.new_model(Thread::new);
+        let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
 
         Self {
             pane,
             thread: thread.clone(),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
+            _subscriptions: subscriptions,
         }
     }
 }
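
(Note: gpui Subscriptions are drop guards, so storing the cx.observe handle in _subscriptions keeps the observation alive for the panel's lifetime; every cx.notify() on the thread model, which the streaming loop below issues once per event, now re-renders the panel.)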


@@ -1,14 +1,11 @@
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::StreamExt;
 use gpui::{AppContext, Model, TextStyle, View};
 use language_model::{
-    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, MessageContent, Role, StopReason,
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 };
 use settings::Settings;
 use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
-use util::ResultExt;
 
 use crate::thread::{self, Thread};
 use crate::Chat;
@@ -71,50 +68,8 @@ impl MessageEditor {
             editor.clear(cx);
         });
 
-        let task = cx.spawn(|this, mut cx| async move {
-            let stream = model.stream_completion(request, &cx);
-            let stream_completion = async {
-                let mut events = stream.await?;
-                let mut stop_reason = StopReason::EndTurn;
-                let mut text = String::new();
-
-                while let Some(event) = events.next().await {
-                    let event = event?;
-                    match event {
-                        LanguageModelCompletionEvent::StartMessage { .. } => {}
-                        LanguageModelCompletionEvent::Stop(reason) => {
-                            stop_reason = reason;
-                        }
-                        LanguageModelCompletionEvent::Text(chunk) => {
-                            text.push_str(&chunk);
-                        }
-                        LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
-                    }
-
-                    smol::future::yield_now().await;
-                }
-
-                anyhow::Ok((stop_reason, text))
-            };
-
-            let result = stream_completion.await;
-
-            this.update(&mut cx, |this, cx| {
-                if let Some((_stop_reason, text)) = result.log_err() {
-                    this.thread.update(cx, |thread, _cx| {
-                        thread.messages.push(thread::Message {
-                            role: Role::Assistant,
-                            text,
-                        });
-                    });
-                }
-            })
-            .ok();
-        });
-
-        self.thread.update(cx, |thread, _cx| {
-            thread.pending_completion_tasks.push(task);
-        });
+        self.thread.update(cx, |thread, cx| {
+            thread.stream_completion(request, model, cx)
+        });
 
         None
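
(With the streaming loop gone, MessageEditor only builds the request and delegates to Thread::stream_completion; the in-flight task is now owned by the thread itself, which is also why the futures, StopReason, and util::ResultExt imports could be dropped in the hunk above.)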


@@ -1,5 +1,11 @@
-use gpui::{ModelContext, Task};
-use language_model::Role;
+use std::sync::Arc;
+
+use futures::StreamExt as _;
+use gpui::{EventEmitter, ModelContext, Task};
+use language_model::{
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason,
+};
+use util::ResultExt as _;
 
 /// A message in a [`Thread`].
 pub struct Message {
@@ -20,4 +26,64 @@ impl Thread {
             pending_completion_tasks: Vec::new(),
         }
     }
+
+    pub fn stream_completion(
+        &mut self,
+        request: LanguageModelRequest,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let task = cx.spawn(|this, mut cx| async move {
+            let stream = model.stream_completion(request, &cx);
+            let stream_completion = async {
+                let mut events = stream.await?;
+                let mut stop_reason = StopReason::EndTurn;
+
+                while let Some(event) = events.next().await {
+                    let event = event?;
+                    this.update(&mut cx, |thread, cx| {
+                        match event {
+                            LanguageModelCompletionEvent::StartMessage { .. } => {
+                                thread.messages.push(Message {
+                                    role: Role::Assistant,
+                                    text: String::new(),
+                                });
+                            }
+                            LanguageModelCompletionEvent::Stop(reason) => {
+                                stop_reason = reason;
+                            }
+                            LanguageModelCompletionEvent::Text(chunk) => {
+                                if let Some(last_message) = thread.messages.last_mut() {
+                                    if last_message.role == Role::Assistant {
+                                        last_message.text.push_str(&chunk);
+                                    }
+                                }
+                            }
+                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
+                        }
+
+                        cx.emit(ThreadEvent::StreamedCompletion);
+                        cx.notify();
+                    })?;
+
+                    smol::future::yield_now().await;
+                }
+
+                anyhow::Ok(stop_reason)
+            };
+
+            let result = stream_completion.await;
+
+            let _ = result.log_err();
+        });
+
+        self.pending_completion_tasks.push(task);
+    }
 }
+
+#[derive(Debug, Clone)]
+pub enum ThreadEvent {
+    StreamedCompletion,
+}
+
+impl EventEmitter<ThreadEvent> for Thread {}
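
ThreadEvent gives consumers a typed signal on top of the observe-based repaint used by the panel above. A minimal sketch of how another view could subscribe to it (the ThreadView type here is hypothetical; cx.subscribe, Model, and Subscription are gpui's own APIs):

    use gpui::{Model, Subscription, ViewContext};

    // Hypothetical view that reacts to streamed completion events.
    struct ThreadView {
        thread: Model<Thread>,
        _subscriptions: Vec<Subscription>,
    }

    impl ThreadView {
        fn new(thread: Model<Thread>, cx: &mut ViewContext<Self>) -> Self {
            // Unlike cx.observe, cx.subscribe delivers the typed ThreadEvent.
            let subscriptions = vec![cx.subscribe(&thread, Self::handle_thread_event)];
            Self {
                thread,
                _subscriptions: subscriptions,
            }
        }

        fn handle_thread_event(
            &mut self,
            _thread: Model<Thread>,
            event: &ThreadEvent,
            cx: &mut ViewContext<Self>,
        ) {
            match event {
                // Repaint on each streamed chunk, e.g. to keep the list scrolled to the bottom.
                ThreadEvent::StreamedCompletion => cx.notify(),
            }
        }
    }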