Incrementally diff input coming from GPT

This commit is contained in:
Antonio Scandurra 2023-08-21 15:11:06 +02:00
parent 3ad7f528cb
commit 42f02eb4e7
6 changed files with 315 additions and 137 deletions

View file

@ -1,7 +1,7 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@ -12,26 +12,23 @@ use editor::{
Anchor, Editor, ToOffset,
};
use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use futures::StreamExt;
use gpui::{
actions,
elements::*,
executor::Background,
geometry::vector::{vec2f, Vector2F},
platform::{CursorStyle, MouseButton},
Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
};
use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use search::BufferSearchBar;
use serde::Deserialize;
use settings::SettingsStore;
use std::{
cell::RefCell,
cmp, env,
fmt::Write,
io, iter,
iter,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
@ -46,8 +43,6 @@ use workspace::{
Save, ToggleZoom, Toolbar, Workspace,
};
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
actions!(
assistant,
[
@ -2144,92 +2139,6 @@ impl Message {
}
}
async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;