diff --git a/Cargo.lock b/Cargo.lock index 58095343f1..103b9c2c2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,20 @@ dependencies = [ "memchr", ] +[[package]] +name = "ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "ctor", + "futures 0.3.28", + "gpui", + "isahc", + "regex", + "serde", + "serde_json", +] + [[package]] name = "alacritty_config" version = "0.1.2-dev" @@ -272,6 +286,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" name = "assistant" version = "0.1.0" dependencies = [ + "ai", "anyhow", "chrono", "client", diff --git a/Cargo.toml b/Cargo.toml index 1c05c810f6..e86960daf4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/activity_indicator", + "crates/ai", "crates/assistant", "crates/audio", "crates/auto_update", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml new file mode 100644 index 0000000000..c4e129c1f5 --- /dev/null +++ b/crates/ai/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[dependencies] +gpui = { path = "../gpui" } +anyhow.workspace = true +futures.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true + +[dev-dependencies] +ctor.workspace = true diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs new file mode 100644 index 0000000000..c893d109ab --- /dev/null +++ b/crates/ai/src/ai.rs @@ -0,0 +1 @@ +pub mod completion; diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs new file mode 100644 index 0000000000..170b2268f9 --- /dev/null +++ b/crates/ai/src/completion.rs @@ -0,0 +1,212 @@ +use anyhow::{anyhow, Result}; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui::executor::Background; +use isahc::{http::StatusCode, Request, RequestExt}; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{self, Display}, + io, + sync::Arc, +}; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + api_key: String, + executor: Arc, + mut request: OpenAIRequest, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +pub trait CompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>>; +} + +pub struct OpenAICompletionProvider { + api_key: String, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(api_key: String, executor: Arc) -> Self { + Self { api_key, executor } + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index a3ee412548..5d141b32d5 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -9,6 +9,7 @@ path = "src/assistant.rs" doctest = false [dependencies] +ai = { path = "../ai" } client = { path = "../client" } collections = { path = "../collections"} editor = { path = "../editor" } diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 48e31bc55a..258684db47 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -3,37 +3,20 @@ mod assistant_settings; mod codegen; mod streaming_diff; -use anyhow::{anyhow, Result}; +use ai::completion::Role; +use anyhow::Result; pub use assistant_panel::AssistantPanel; use assistant_settings::OpenAIModel; use chrono::{DateTime, Local}; use collections::HashMap; use fs::Fs; -use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; -use gpui::{executor::Background, AppContext}; -use isahc::{http::StatusCode, Request, RequestExt}; +use futures::StreamExt; +use gpui::AppContext; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::{ - cmp::Reverse, - ffi::OsStr, - fmt::{self, Display}, - io, - path::PathBuf, - sync::Arc, -}; +use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc}; use util::paths::CONVERSATIONS_DIR; -const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; - -// Data types for chat completion requests -#[derive(Debug, Default, Serialize)] -pub struct OpenAIRequest { - model: String, - messages: Vec, - stream: bool, -} - #[derive( Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, )] @@ -116,175 +99,10 @@ impl SavedConversationMetadata { } } -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -struct RequestMessage { - role: Role, - content: String, -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct ResponseMessage { - role: Option, - content: Option, -} - -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -enum Role { - User, - Assistant, - System, -} - -impl Role { - pub fn cycle(&mut self) { - *self = match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "User"), - Role::Assistant => write!(f, "Assistant"), - Role::System => write!(f, "System"), - } - } -} - -#[derive(Deserialize, Debug)] -pub struct OpenAIResponseStreamEvent { - pub id: Option, - pub object: String, - pub created: u32, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Deserialize, Debug)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Deserialize, Debug)] -pub struct ChatChoiceDelta { - pub index: u32, - pub delta: ResponseMessage, - pub finish_reason: Option, -} - -#[derive(Deserialize, Debug)] -struct OpenAIUsage { - prompt_tokens: u64, - completion_tokens: u64, - total_tokens: u64, -} - -#[derive(Deserialize, Debug)] -struct OpenAIChoice { - text: String, - index: u32, - logprobs: Option, - finish_reason: Option, -} - pub fn init(cx: &mut AppContext) { assistant_panel::init(cx); } -pub async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; - - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? - .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } - - if done { - break; - } - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAIResponse { - error: OpenAIError, - } - - #[derive(Deserialize)] - struct OpenAIError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), - } - } -} - #[cfg(test)] #[ctor::ctor] fn init_logger() { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 263382c03e..42e5fb7897 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,8 +1,11 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider}, - stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, - Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, + codegen::{self, Codegen, CodegenKind}, + MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, + SavedMessage, +}; +use ai::completion::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e7da46cdf9..e956d72260 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,59 +1,14 @@ -use crate::{ - stream_completion, - streaming_diff::{Hunk, StreamingDiff}, - OpenAIRequest, -}; +use crate::streaming_diff::{Hunk, StreamingDiff}; +use ai::completion::{CompletionProvider, OpenAIRequest}; use anyhow::Result; use editor::{ multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, }; -use futures::{ - channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, -}; -use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; +use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; +use gpui::{Entity, ModelContext, ModelHandle, Task}; use language::{Rope, TransactionId}; use std::{cmp, future, ops::Range, sync::Arc}; -pub trait CompletionProvider { - fn complete( - &self, - prompt: OpenAIRequest, - ) -> BoxFuture<'static, Result>>>; -} - -pub struct OpenAICompletionProvider { - api_key: String, - executor: Arc, -} - -impl OpenAICompletionProvider { - pub fn new(api_key: String, executor: Arc) -> Self { - Self { api_key, executor } - } -} - -impl CompletionProvider for OpenAICompletionProvider { - fn complete( - &self, - prompt: OpenAIRequest, - ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); - async move { - let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) - } - .boxed() - } -} - pub enum Event { Finished, Undone, @@ -397,13 +352,17 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use futures::stream; + use futures::{ + future::BoxFuture, + stream::{self, BoxStream}, + }; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; use settings::SettingsStore; + use smol::future::FutureExt; #[gpui::test(iterations = 10)] async fn test_transform_autoindent( diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index dcdbd004b7..bdf060205a 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -5,9 +5,9 @@ pub mod only_instance; #[cfg(any(test, feature = "test-support"))] pub mod test; -use assistant::AssistantPanel; use anyhow::Context; use assets::Assets; +use assistant::AssistantPanel; use breadcrumbs::Breadcrumbs; pub use client; use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut @@ -2418,7 +2418,7 @@ mod tests { pane::init(cx); project_panel::init((), cx); terminal_view::init(cx); - ai::init(cx); + assistant::init(cx); app_state }) }