diff --git a/Cargo.lock b/Cargo.lock index d339ac3256..59f933f4b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "editor", + "gpui", + "pulldown-cmark", + "unindent", +] + [[package]] name = "alacritty_config" version = "0.1.1-dev" @@ -342,6 +354,28 @@ dependencies = [ "futures-lite", ] +[[package]] +name = "async-openai" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d5e93aca1b2f0ca772c76cadd43e965809df87ef98e25e47244c7f006c85d2" +dependencies = [ + "backoff", + "base64 0.21.0", + "derive_builder", + "futures 0.3.28", + "rand 0.8.5", + "reqwest", + "reqwest-eventsource", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util 0.7.8", + "tracing", +] + [[package]] name = "async-pipe" version = "0.1.3" @@ -642,6 +676,20 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.9", + "instant", + "pin-project-lite 0.2.9", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.67" @@ -1801,6 +1849,41 @@ dependencies = [ "syn 2.0.15", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + [[package]] name = "dashmap" version = "5.4.0" @@ -1855,6 +1938,37 @@ dependencies = [ "byteorder", ] +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", +] + [[package]] name = "dhat" version = "0.3.2" @@ -2190,6 +2304,17 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite 0.2.9", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -2586,6 +2711,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.28" @@ -2633,6 +2764,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.1.16" @@ -3060,6 +3200,19 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" +dependencies = [ + "http", + "hyper", + "rustls 0.20.8", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -3109,6 +3262,12 @@ dependencies = [ "cxx-build", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.3.0" @@ -3903,6 +4062,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -5071,6 +5240,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63" dependencies = [ "bitflags", + "getopts", "memchr", "unicase", ] @@ -5367,28 +5537,52 @@ dependencies = [ "http", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite 0.2.9", + "rustls 0.20.8", + "rustls-native-certs", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", + "tokio-rustls", + "tokio-util 0.7.8", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite 0.2.9", + "reqwest", + "thiserror", +] + [[package]] name = "resvg" version = "0.14.1" @@ -5676,6 +5870,18 @@ dependencies = [ "webpki 0.22.0", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -8039,6 +8245,19 @@ dependencies = [ "leb128", ] +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.85.0" @@ -8759,6 +8978,7 @@ name = "zed" version = "0.88.0" dependencies = [ "activity_indicator", + "ai", "anyhow", "assets", "async-compression", diff --git a/Cargo.toml b/Cargo.toml index f14e1c7355..77252802d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/activity_indicator", + "crates/ai", "crates/assets", "crates/auto_update", "crates/breadcrumbs", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml new file mode 100644 index 0000000000..30dc5ee5a2 --- /dev/null +++ b/crates/ai/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[dependencies] +editor = { path = "../editor" } +gpui = { path = "../gpui" } + +anyhow.workspace = true +async-openai = "0.10.3" +pulldown-cmark = "0.9.2" +unindent.workspace = true + +[dev-dependencies] +editor = { path = "../editor", features = ["test-support"] } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs new file mode 100644 index 0000000000..0ae960e281 --- /dev/null +++ b/crates/ai/src/ai.rs @@ -0,0 +1,112 @@ +use anyhow::Result; +use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequest, Role}; +use editor::Editor; +use gpui::{actions, AppContext, Task, ViewContext}; +use pulldown_cmark::{Event, HeadingLevel, Parser, Tag}; + +actions!(ai, [Assist]); + +pub fn init(cx: &mut AppContext) { + cx.add_async_action(assist) +} + +fn assist( + editor: &mut Editor, + _: &Assist, + cx: &mut ViewContext, +) -> Option>> { + let markdown = editor.text(cx); + parse_dialog(&markdown); + None +} + +fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest { + let parser = Parser::new(markdown); + let mut messages = Vec::new(); + + let mut current_role: Option<(Role, Option)> = None; + let mut buffer = String::new(); + for event in parser { + match event { + Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => { + if let Some((role, name)) = current_role.take() { + if !buffer.is_empty() { + messages.push(ChatCompletionRequestMessage { + role, + content: buffer.trim().to_string(), + name, + }); + buffer.clear(); + } + } + } + Event::Text(text) => { + if current_role.is_some() { + buffer.push_str(&text); + } else { + // Determine the current role based on the H2 header text + let mut chars = text.chars(); + let first_char = chars.by_ref().skip_while(|c| c.is_whitespace()).next(); + let name = chars.take_while(|c| *c != '\n').collect::(); + let name = if name.is_empty() { None } else { Some(name) }; + + let role = match first_char { + Some('@') => Some(Role::User), + Some('/') => Some(Role::Assistant), + Some('#') => Some(Role::System), + _ => None, + }; + + current_role = role.map(|role| (role, name)); + } + } + _ => (), + } + } + if let Some((role, name)) = current_role { + messages.push(ChatCompletionRequestMessage { + role, + content: buffer, + name, + }); + } + + CreateChatCompletionRequest { + model: "gpt-4".into(), + messages, + ..Default::default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_dialog() { + use unindent::Unindent; + + let test_input = r#" + ## @nathan + Hey there, welcome to Zed! + + ## /sky + Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast. + "#.unindent(); + + let expected_output = vec![ + ChatCompletionRequestMessage { + role: Role::User, + content: "Hey there, welcome to Zed!".to_string(), + name: Some("nathan".to_string()), + }, + ChatCompletionRequestMessage { + role: Role::Assistant, + content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(), + name: Some("sky".to_string()), + }, + ]; + + assert_eq!(parse_dialog(&test_input).messages, expected_output); + } +} diff --git a/crates/live_kit_client/Cargo.toml b/crates/live_kit_client/Cargo.toml index 2d61e75732..36087a42a3 100644 --- a/crates/live_kit_client/Cargo.toml +++ b/crates/live_kit_client/Cargo.toml @@ -46,6 +46,7 @@ collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } live_kit_server = { path = "../live_kit_server" } media = { path = "../media" } +nanoid = "0.4" anyhow.workspace = true async-trait.workspace = true diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 90dced65f5..e24b7ef232 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -48,6 +48,7 @@ language_selector = { path = "../language_selector" } lsp = { path = "../lsp" } lsp_log = { path = "../lsp_log" } node_runtime = { path = "../node_runtime" } +ai = { path = "../ai" } outline = { path = "../outline" } plugin_runtime = { path = "../plugin_runtime" } project = { path = "../project" } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 2f359240bc..eb2d693700 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -162,6 +162,7 @@ fn main() { terminal_view::init(cx); theme_testbench::init(cx); copilot::init(http.clone(), node_runtime, cx); + ai::init(cx); cx.spawn(|cx| watch_themes(fs.clone(), cx)).detach();