WIP: Stream in completions
Drop dependency on tokio introduced by async-openai and do it ourselves. The approach I'm taking of replacing instead of appending is causing issues. Need to just append.
This commit is contained in:
parent
912fd23006
commit
7e6cccfa3d
9 changed files with 209 additions and 236 deletions
204
Cargo.lock
generated
204
Cargo.lock
generated
|
@ -100,11 +100,16 @@ name = "ai"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-openai",
|
"async-stream",
|
||||||
"editor",
|
"editor",
|
||||||
|
"futures 0.3.28",
|
||||||
"gpui",
|
"gpui",
|
||||||
|
"isahc",
|
||||||
"pulldown-cmark",
|
"pulldown-cmark",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"unindent",
|
"unindent",
|
||||||
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -354,28 +359,6 @@ dependencies = [
|
||||||
"futures-lite",
|
"futures-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "async-openai"
|
|
||||||
version = "0.10.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e5d5e93aca1b2f0ca772c76cadd43e965809df87ef98e25e47244c7f006c85d2"
|
|
||||||
dependencies = [
|
|
||||||
"backoff",
|
|
||||||
"base64 0.21.0",
|
|
||||||
"derive_builder",
|
|
||||||
"futures 0.3.28",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"reqwest",
|
|
||||||
"reqwest-eventsource",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"thiserror",
|
|
||||||
"tokio",
|
|
||||||
"tokio-stream",
|
|
||||||
"tokio-util 0.7.8",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-pipe"
|
name = "async-pipe"
|
||||||
version = "0.1.3"
|
version = "0.1.3"
|
||||||
|
@ -676,20 +659,6 @@ dependencies = [
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "backoff"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
|
|
||||||
dependencies = [
|
|
||||||
"futures-core",
|
|
||||||
"getrandom 0.2.9",
|
|
||||||
"instant",
|
|
||||||
"pin-project-lite 0.2.9",
|
|
||||||
"rand 0.8.5",
|
|
||||||
"tokio",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backtrace"
|
name = "backtrace"
|
||||||
version = "0.3.67"
|
version = "0.3.67"
|
||||||
|
@ -1849,41 +1818,6 @@ dependencies = [
|
||||||
"syn 2.0.15",
|
"syn 2.0.15",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling"
|
|
||||||
version = "0.14.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
|
|
||||||
dependencies = [
|
|
||||||
"darling_core",
|
|
||||||
"darling_macro",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling_core"
|
|
||||||
version = "0.14.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
|
|
||||||
dependencies = [
|
|
||||||
"fnv",
|
|
||||||
"ident_case",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"strsim 0.10.0",
|
|
||||||
"syn 1.0.109",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "darling_macro"
|
|
||||||
version = "0.14.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
|
|
||||||
dependencies = [
|
|
||||||
"darling_core",
|
|
||||||
"quote",
|
|
||||||
"syn 1.0.109",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dashmap"
|
name = "dashmap"
|
||||||
version = "5.4.0"
|
version = "5.4.0"
|
||||||
|
@ -1938,37 +1872,6 @@ dependencies = [
|
||||||
"byteorder",
|
"byteorder",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder"
|
|
||||||
version = "0.12.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
|
|
||||||
dependencies = [
|
|
||||||
"derive_builder_macro",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder_core"
|
|
||||||
version = "0.12.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
|
|
||||||
dependencies = [
|
|
||||||
"darling",
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn 1.0.109",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_builder_macro"
|
|
||||||
version = "0.12.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
|
|
||||||
dependencies = [
|
|
||||||
"derive_builder_core",
|
|
||||||
"syn 1.0.109",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dhat"
|
name = "dhat"
|
||||||
version = "0.3.2"
|
version = "0.3.2"
|
||||||
|
@ -2304,17 +2207,6 @@ version = "2.5.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
|
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "eventsource-stream"
|
|
||||||
version = "0.2.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
|
|
||||||
dependencies = [
|
|
||||||
"futures-core",
|
|
||||||
"nom",
|
|
||||||
"pin-project-lite 0.2.9",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fallible-iterator"
|
name = "fallible-iterator"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
|
@ -2711,12 +2603,6 @@ version = "0.3.28"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "futures-timer"
|
|
||||||
version = "3.0.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-util"
|
name = "futures-util"
|
||||||
version = "0.3.28"
|
version = "0.3.28"
|
||||||
|
@ -3200,19 +3086,6 @@ dependencies = [
|
||||||
"want",
|
"want",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "hyper-rustls"
|
|
||||||
version = "0.23.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c"
|
|
||||||
dependencies = [
|
|
||||||
"http",
|
|
||||||
"hyper",
|
|
||||||
"rustls 0.20.8",
|
|
||||||
"tokio",
|
|
||||||
"tokio-rustls",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper-timeout"
|
name = "hyper-timeout"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
@ -3262,12 +3135,6 @@ dependencies = [
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ident_case"
|
|
||||||
version = "1.0.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "0.3.0"
|
version = "0.3.0"
|
||||||
|
@ -4062,16 +3929,6 @@ version = "0.3.17"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "mime_guess"
|
|
||||||
version = "2.0.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
|
|
||||||
dependencies = [
|
|
||||||
"mime",
|
|
||||||
"unicase",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
@ -5537,52 +5394,28 @@ dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-rustls",
|
|
||||||
"hyper-tls",
|
"hyper-tls",
|
||||||
"ipnet",
|
"ipnet",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"log",
|
"log",
|
||||||
"mime",
|
"mime",
|
||||||
"mime_guess",
|
|
||||||
"native-tls",
|
"native-tls",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite 0.2.9",
|
"pin-project-lite 0.2.9",
|
||||||
"rustls 0.20.8",
|
|
||||||
"rustls-native-certs",
|
|
||||||
"rustls-pemfile",
|
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-native-tls",
|
"tokio-native-tls",
|
||||||
"tokio-rustls",
|
|
||||||
"tokio-util 0.7.8",
|
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
"wasm-streams",
|
|
||||||
"web-sys",
|
"web-sys",
|
||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "reqwest-eventsource"
|
|
||||||
version = "0.4.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51"
|
|
||||||
dependencies = [
|
|
||||||
"eventsource-stream",
|
|
||||||
"futures-core",
|
|
||||||
"futures-timer",
|
|
||||||
"mime",
|
|
||||||
"nom",
|
|
||||||
"pin-project-lite 0.2.9",
|
|
||||||
"reqwest",
|
|
||||||
"thiserror",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "resvg"
|
name = "resvg"
|
||||||
version = "0.14.1"
|
version = "0.14.1"
|
||||||
|
@ -5870,18 +5703,6 @@ dependencies = [
|
||||||
"webpki 0.22.0",
|
"webpki 0.22.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rustls-native-certs"
|
|
||||||
version = "0.6.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
|
|
||||||
dependencies = [
|
|
||||||
"openssl-probe",
|
|
||||||
"rustls-pemfile",
|
|
||||||
"schannel",
|
|
||||||
"security-framework",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pemfile"
|
name = "rustls-pemfile"
|
||||||
version = "1.0.2"
|
version = "1.0.2"
|
||||||
|
@ -8245,19 +8066,6 @@ dependencies = [
|
||||||
"leb128",
|
"leb128",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "wasm-streams"
|
|
||||||
version = "0.2.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
|
|
||||||
dependencies = [
|
|
||||||
"futures-util",
|
|
||||||
"js-sys",
|
|
||||||
"wasm-bindgen",
|
|
||||||
"wasm-bindgen-futures",
|
|
||||||
"web-sys",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasmparser"
|
name = "wasmparser"
|
||||||
version = "0.85.0"
|
version = "0.85.0"
|
||||||
|
|
|
@ -79,6 +79,7 @@ ctor = { version = "0.1" }
|
||||||
env_logger = { version = "0.9" }
|
env_logger = { version = "0.9" }
|
||||||
futures = { version = "0.3" }
|
futures = { version = "0.3" }
|
||||||
glob = { version = "0.3.1" }
|
glob = { version = "0.3.1" }
|
||||||
|
isahc = "1.7.2"
|
||||||
lazy_static = { version = "1.4.0" }
|
lazy_static = { version = "1.4.0" }
|
||||||
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
|
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
|
||||||
ordered-float = { version = "2.1.1" }
|
ordered-float = { version = "2.1.1" }
|
||||||
|
|
|
@ -11,11 +11,16 @@ doctest = false
|
||||||
[dependencies]
|
[dependencies]
|
||||||
editor = { path = "../editor" }
|
editor = { path = "../editor" }
|
||||||
gpui = { path = "../gpui" }
|
gpui = { path = "../gpui" }
|
||||||
|
util = { path = "../util" }
|
||||||
|
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-openai = "0.10.3"
|
|
||||||
pulldown-cmark = "0.9.2"
|
pulldown-cmark = "0.9.2"
|
||||||
|
futures.workspace = true
|
||||||
|
isahc.workspace = true
|
||||||
unindent.workspace = true
|
unindent.workspace = true
|
||||||
|
async-stream = "0.3.5"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
editor = { path = "../editor", features = ["test-support"] }
|
editor = { path = "../editor", features = ["test-support"] }
|
||||||
|
|
|
@ -1,11 +1,87 @@
|
||||||
use anyhow::Result;
|
use std::io;
|
||||||
use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequest, Role};
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
|
use futures::AsyncBufReadExt;
|
||||||
|
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
|
||||||
|
use gpui::executor::Foreground;
|
||||||
use gpui::{actions, AppContext, Task, ViewContext};
|
use gpui::{actions, AppContext, Task, ViewContext};
|
||||||
|
use isahc::prelude::*;
|
||||||
|
use isahc::{http::StatusCode, Request};
|
||||||
use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
|
use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
actions!(ai, [Assist]);
|
actions!(ai, [Assist]);
|
||||||
|
|
||||||
|
// Data types for chat completion requests
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OpenAIRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<RequestMessage>,
|
||||||
|
stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
struct RequestMessage {
|
||||||
|
role: Role,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
struct ResponseMessage {
|
||||||
|
role: Option<Role>,
|
||||||
|
content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
System,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct OpenAIResponseStreamEvent {
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u32,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoiceDelta>,
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct ChatChoiceDelta {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ResponseMessage,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct OpenAIUsage {
|
||||||
|
prompt_tokens: u64,
|
||||||
|
completion_tokens: u64,
|
||||||
|
total_tokens: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct OpenAIChoice {
|
||||||
|
text: String,
|
||||||
|
index: u32,
|
||||||
|
logprobs: Option<serde_json::Value>,
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
cx.add_async_action(assist)
|
cx.add_async_action(assist)
|
||||||
}
|
}
|
||||||
|
@ -15,26 +91,58 @@ fn assist(
|
||||||
_: &Assist,
|
_: &Assist,
|
||||||
cx: &mut ViewContext<Editor>,
|
cx: &mut ViewContext<Editor>,
|
||||||
) -> Option<Task<Result<()>>> {
|
) -> Option<Task<Result<()>>> {
|
||||||
|
let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
|
||||||
|
|
||||||
let markdown = editor.text(cx);
|
let markdown = editor.text(cx);
|
||||||
parse_dialog(&markdown);
|
let prompt = parse_dialog(&markdown);
|
||||||
None
|
let response = stream_completion(api_key, prompt, cx.foreground().clone());
|
||||||
|
|
||||||
|
let range = editor.buffer().update(cx, |buffer, cx| {
|
||||||
|
let snapshot = buffer.snapshot(cx);
|
||||||
|
let chars = snapshot.reversed_chars_at(snapshot.len());
|
||||||
|
let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
|
||||||
|
let suffix = "\n".repeat(2 - trailing_newlines);
|
||||||
|
let end = snapshot.len();
|
||||||
|
buffer.edit([(end..end, suffix.clone())], None, cx);
|
||||||
|
let snapshot = buffer.snapshot(cx);
|
||||||
|
let start = snapshot.anchor_before(snapshot.len());
|
||||||
|
let end = snapshot.anchor_after(snapshot.len());
|
||||||
|
start..end
|
||||||
|
});
|
||||||
|
let buffer = editor.buffer().clone();
|
||||||
|
|
||||||
|
Some(cx.spawn(|_, mut cx| async move {
|
||||||
|
let mut stream = response.await?;
|
||||||
|
let mut message = String::new();
|
||||||
|
while let Some(stream_event) = stream.next().await {
|
||||||
|
if let Some(choice) = stream_event?.choices.first() {
|
||||||
|
if let Some(content) = &choice.delta.content {
|
||||||
|
message.push_str(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.update(&mut cx, |buffer, cx| {
|
||||||
|
buffer.edit([(range.clone(), message.clone())], None, cx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest {
|
fn parse_dialog(markdown: &str) -> OpenAIRequest {
|
||||||
let parser = Parser::new(markdown);
|
let parser = Parser::new(markdown);
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
let mut current_role: Option<(Role, Option<String>)> = None;
|
let mut current_role: Option<Role> = None;
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
for event in parser {
|
for event in parser {
|
||||||
match event {
|
match event {
|
||||||
Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
|
Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
|
||||||
if let Some((role, name)) = current_role.take() {
|
if let Some(role) = current_role.take() {
|
||||||
if !buffer.is_empty() {
|
if !buffer.is_empty() {
|
||||||
messages.push(ChatCompletionRequestMessage {
|
messages.push(RequestMessage {
|
||||||
role,
|
role,
|
||||||
content: buffer.trim().to_string(),
|
content: buffer.trim().to_string(),
|
||||||
name,
|
|
||||||
});
|
});
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
}
|
}
|
||||||
|
@ -45,36 +153,89 @@ fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest {
|
||||||
buffer.push_str(&text);
|
buffer.push_str(&text);
|
||||||
} else {
|
} else {
|
||||||
// Determine the current role based on the H2 header text
|
// Determine the current role based on the H2 header text
|
||||||
let mut chars = text.chars();
|
let text = text.to_lowercase();
|
||||||
let first_char = chars.by_ref().skip_while(|c| c.is_whitespace()).next();
|
current_role = if text.contains("user") {
|
||||||
let name = chars.take_while(|c| *c != '\n').collect::<String>();
|
Some(Role::User)
|
||||||
let name = if name.is_empty() { None } else { Some(name) };
|
} else if text.contains("assistant") {
|
||||||
|
Some(Role::Assistant)
|
||||||
let role = match first_char {
|
} else if text.contains("system") {
|
||||||
Some('@') => Some(Role::User),
|
Some(Role::System)
|
||||||
Some('/') => Some(Role::Assistant),
|
} else {
|
||||||
Some('#') => Some(Role::System),
|
None
|
||||||
_ => None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
current_role = role.map(|role| (role, name));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => (),
|
_ => (),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some((role, name)) = current_role {
|
if let Some(role) = current_role {
|
||||||
messages.push(ChatCompletionRequestMessage {
|
messages.push(RequestMessage {
|
||||||
role,
|
role,
|
||||||
content: buffer,
|
content: buffer,
|
||||||
name,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
CreateChatCompletionRequest {
|
OpenAIRequest {
|
||||||
model: "gpt-4".into(),
|
model: "gpt-4".into(),
|
||||||
messages,
|
messages,
|
||||||
..Default::default()
|
stream: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_completion(
|
||||||
|
api_key: String,
|
||||||
|
mut request: OpenAIRequest,
|
||||||
|
executor: Rc<Foreground>,
|
||||||
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
|
request.stream = true;
|
||||||
|
|
||||||
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
|
let json_data = serde_json::to_string(&request)?;
|
||||||
|
let mut response = Request::post("https://api.openai.com/v1/chat/completions")
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(json_data)?
|
||||||
|
.send_async()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if status == StatusCode::OK {
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||||
|
|
||||||
|
fn parse_line(
|
||||||
|
line: Result<String, io::Error>,
|
||||||
|
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||||
|
if let Some(data) = line?.strip_prefix("data: ") {
|
||||||
|
let event = serde_json::from_str(&data)?;
|
||||||
|
Ok(Some(event))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(line) = lines.next().await {
|
||||||
|
if let Some(event) = parse_line(line).transpose() {
|
||||||
|
tx.unbounded_send(event).log_err();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,23 +248,21 @@ mod tests {
|
||||||
use unindent::Unindent;
|
use unindent::Unindent;
|
||||||
|
|
||||||
let test_input = r#"
|
let test_input = r#"
|
||||||
## @nathan
|
## System
|
||||||
Hey there, welcome to Zed!
|
Hey there, welcome to Zed!
|
||||||
|
|
||||||
## /sky
|
## Assintant
|
||||||
Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
|
Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
|
||||||
"#.unindent();
|
"#.unindent();
|
||||||
|
|
||||||
let expected_output = vec![
|
let expected_output = vec![
|
||||||
ChatCompletionRequestMessage {
|
RequestMessage {
|
||||||
role: Role::User,
|
role: Role::User,
|
||||||
content: "Hey there, welcome to Zed!".to_string(),
|
content: "Hey there, welcome to Zed!".to_string(),
|
||||||
name: Some("nathan".to_string()),
|
|
||||||
},
|
},
|
||||||
ChatCompletionRequestMessage {
|
RequestMessage {
|
||||||
role: Role::Assistant,
|
role: Role::Assistant,
|
||||||
content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
|
content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
|
||||||
name: Some("sky".to_string()),
|
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ theme = { path = "../theme" }
|
||||||
workspace = { path = "../workspace" }
|
workspace = { path = "../workspace" }
|
||||||
util = { path = "../util" }
|
util = { path = "../util" }
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
isahc = "1.7"
|
isahc.workspace = true
|
||||||
lazy_static.workspace = true
|
lazy_static.workspace = true
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
|
|
@ -27,7 +27,7 @@ futures.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
smallvec.workspace = true
|
smallvec.workspace = true
|
||||||
human_bytes = "0.4.1"
|
human_bytes = "0.4.1"
|
||||||
isahc = "1.7"
|
isahc.workspace = true
|
||||||
lazy_static.workspace = true
|
lazy_static.workspace = true
|
||||||
postage.workspace = true
|
postage.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
|
|
|
@ -960,7 +960,7 @@ impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
|
||||||
pub fn detach_and_log_err(self, cx: &mut AppContext) {
|
pub fn detach_and_log_err(self, cx: &mut AppContext) {
|
||||||
cx.spawn(|_| async move {
|
cx.spawn(|_| async move {
|
||||||
if let Err(err) = self.await {
|
if let Err(err) = self.await {
|
||||||
log::error!("{}", err);
|
log::error!("{:#}", err);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
|
|
@ -17,7 +17,7 @@ backtrace = "0.3"
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
lazy_static.workspace = true
|
lazy_static.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
isahc = "1.7"
|
isahc.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
url = "2.2"
|
url = "2.2"
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
|
|
|
@ -82,7 +82,7 @@ futures.workspace = true
|
||||||
ignore = "0.4"
|
ignore = "0.4"
|
||||||
image = "0.23"
|
image = "0.23"
|
||||||
indexmap = "1.6.2"
|
indexmap = "1.6.2"
|
||||||
isahc = "1.7"
|
isahc.workspace = true
|
||||||
lazy_static.workspace = true
|
lazy_static.workspace = true
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue