introduce ai crate with completion providers

This commit is contained in:
KCaverly 2023-09-21 22:44:56 -04:00
parent 5f6334696a
commit 48e151495f
10 changed files with 273 additions and 242 deletions

15
Cargo.lock generated
View file

@ -86,6 +86,20 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "ai"
version = "0.1.0"
dependencies = [
"anyhow",
"ctor",
"futures 0.3.28",
"gpui",
"isahc",
"regex",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "alacritty_config" name = "alacritty_config"
version = "0.1.2-dev" version = "0.1.2-dev"
@ -272,6 +286,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
name = "assistant" name = "assistant"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ai",
"anyhow", "anyhow",
"chrono", "chrono",
"client", "client",

View file

@ -1,6 +1,7 @@
[workspace] [workspace]
members = [ members = [
"crates/activity_indicator", "crates/activity_indicator",
"crates/ai",
"crates/assistant", "crates/assistant",
"crates/audio", "crates/audio",
"crates/auto_update", "crates/auto_update",

21
crates/ai/Cargo.toml Normal file
View file

@ -0,0 +1,21 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai.rs"
doctest = false
[dependencies]
gpui = { path = "../gpui" }
anyhow.workspace = true
futures.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
[dev-dependencies]
ctor.workspace = true

1
crates/ai/src/ai.rs Normal file
View file

@ -0,0 +1 @@
pub mod completion;

212
crates/ai/src/completion.rs Normal file
View file

@ -0,0 +1,212 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
}

View file

@ -9,6 +9,7 @@ path = "src/assistant.rs"
doctest = false doctest = false
[dependencies] [dependencies]
ai = { path = "../ai" }
client = { path = "../client" } client = { path = "../client" }
collections = { path = "../collections"} collections = { path = "../collections"}
editor = { path = "../editor" } editor = { path = "../editor" }

View file

@ -3,37 +3,20 @@ mod assistant_settings;
mod codegen; mod codegen;
mod streaming_diff; mod streaming_diff;
use anyhow::{anyhow, Result}; use ai::completion::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel; pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel; use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
use collections::HashMap; use collections::HashMap;
use fs::Fs; use fs::Fs;
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use futures::StreamExt;
use gpui::{executor::Background, AppContext}; use gpui::AppContext;
use isahc::{http::StatusCode, Request, RequestExt};
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
cmp::Reverse,
ffi::OsStr,
fmt::{self, Display},
io,
path::PathBuf,
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR; use util::paths::CONVERSATIONS_DIR;
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
// Data types for chat completion requests
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
model: String,
messages: Vec<RequestMessage>,
stream: bool,
}
#[derive( #[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize, Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)] )]
@ -116,175 +99,10 @@ impl SavedConversationMetadata {
} }
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
role: Role,
content: String,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
role: Option<Role>,
content: Option<String>,
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
prompt_tokens: u64,
completion_tokens: u64,
total_tokens: u64,
}
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
text: String,
index: u32,
logprobs: Option<serde_json::Value>,
finish_reason: Option<String>,
}
pub fn init(cx: &mut AppContext) { pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx); assistant_panel::init(cx);
} }
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[cfg(test)] #[cfg(test)]
#[ctor::ctor] #[ctor::ctor]
fn init_logger() { fn init_logger() {

View file

@ -1,8 +1,11 @@
use crate::{ use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider}, codegen::{self, Codegen, CodegenKind},
stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, SavedMessage,
};
use ai::completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};

View file

@ -1,59 +1,14 @@
use crate::{ use crate::streaming_diff::{Hunk, StreamingDiff};
stream_completion, use ai::completion::{CompletionProvider, OpenAIRequest};
streaming_diff::{Hunk, StreamingDiff},
OpenAIRequest,
};
use anyhow::Result; use anyhow::Result;
use editor::{ use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
}; };
use futures::{ use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, use gpui::{Entity, ModelContext, ModelHandle, Task};
};
use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId}; use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range, sync::Arc}; use std::{cmp, future, ops::Range, sync::Arc};
pub trait CompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
}
pub enum Event { pub enum Event {
Finished, Finished,
Undone, Undone,
@ -397,13 +352,17 @@ fn strip_markdown_codeblock(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::stream; use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext}; use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc; use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex; use parking_lot::Mutex;
use rand::prelude::*; use rand::prelude::*;
use settings::SettingsStore; use settings::SettingsStore;
use smol::future::FutureExt;
#[gpui::test(iterations = 10)] #[gpui::test(iterations = 10)]
async fn test_transform_autoindent( async fn test_transform_autoindent(

View file

@ -5,9 +5,9 @@ pub mod only_instance;
#[cfg(any(test, feature = "test-support"))] #[cfg(any(test, feature = "test-support"))]
pub mod test; pub mod test;
use assistant::AssistantPanel;
use anyhow::Context; use anyhow::Context;
use assets::Assets; use assets::Assets;
use assistant::AssistantPanel;
use breadcrumbs::Breadcrumbs; use breadcrumbs::Breadcrumbs;
pub use client; pub use client;
use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
@ -2418,7 +2418,7 @@ mod tests {
pane::init(cx); pane::init(cx);
project_panel::init((), cx); project_panel::init((), cx);
terminal_view::init(cx); terminal_view::init(cx);
ai::init(cx); assistant::init(cx);
app_state app_state
}) })
} }