introduce ai crate with completion providers
This commit is contained in:
parent
5f6334696a
commit
48e151495f
10 changed files with 273 additions and 242 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -86,6 +86,20 @@ dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ai"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"ctor",
|
||||||
|
"futures 0.3.28",
|
||||||
|
"gpui",
|
||||||
|
"isahc",
|
||||||
|
"regex",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alacritty_config"
|
name = "alacritty_config"
|
||||||
version = "0.1.2-dev"
|
version = "0.1.2-dev"
|
||||||
|
@ -272,6 +286,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
|
||||||
name = "assistant"
|
name = "assistant"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"ai",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"chrono",
|
"chrono",
|
||||||
"client",
|
"client",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"crates/activity_indicator",
|
"crates/activity_indicator",
|
||||||
|
"crates/ai",
|
||||||
"crates/assistant",
|
"crates/assistant",
|
||||||
"crates/audio",
|
"crates/audio",
|
||||||
"crates/auto_update",
|
"crates/auto_update",
|
||||||
|
|
21
crates/ai/Cargo.toml
Normal file
21
crates/ai/Cargo.toml
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
[package]
|
||||||
|
name = "ai"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/ai.rs"
|
||||||
|
doctest = false
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
gpui = { path = "../gpui" }
|
||||||
|
anyhow.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
isahc.workspace = true
|
||||||
|
regex.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
ctor.workspace = true
|
1
crates/ai/src/ai.rs
Normal file
1
crates/ai/src/ai.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
pub mod completion;
|
212
crates/ai/src/completion.rs
Normal file
212
crates/ai/src/completion.rs
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use futures::{
|
||||||
|
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||||
|
Stream, StreamExt,
|
||||||
|
};
|
||||||
|
use gpui::executor::Background;
|
||||||
|
use isahc::{http::StatusCode, Request, RequestExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
fmt::{self, Display},
|
||||||
|
io,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
System,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Role {
|
||||||
|
pub fn cycle(&mut self) {
|
||||||
|
*self = match self {
|
||||||
|
Role::User => Role::Assistant,
|
||||||
|
Role::Assistant => Role::System,
|
||||||
|
Role::System => Role::User,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Role {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Role::User => write!(f, "User"),
|
||||||
|
Role::Assistant => write!(f, "Assistant"),
|
||||||
|
Role::System => write!(f, "System"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct RequestMessage {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Serialize)]
|
||||||
|
pub struct OpenAIRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<RequestMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ResponseMessage {
|
||||||
|
pub role: Option<Role>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIUsage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct ChatChoiceDelta {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ResponseMessage,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIResponseStreamEvent {
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u32,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoiceDelta>,
|
||||||
|
pub usage: Option<OpenAIUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stream_completion(
|
||||||
|
api_key: String,
|
||||||
|
executor: Arc<Background>,
|
||||||
|
mut request: OpenAIRequest,
|
||||||
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
|
request.stream = true;
|
||||||
|
|
||||||
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
|
let json_data = serde_json::to_string(&request)?;
|
||||||
|
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(json_data)?
|
||||||
|
.send_async()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if status == StatusCode::OK {
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||||
|
|
||||||
|
fn parse_line(
|
||||||
|
line: Result<String, io::Error>,
|
||||||
|
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||||
|
if let Some(data) = line?.strip_prefix("data: ") {
|
||||||
|
let event = serde_json::from_str(&data)?;
|
||||||
|
Ok(Some(event))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(line) = lines.next().await {
|
||||||
|
if let Some(event) = parse_line(line).transpose() {
|
||||||
|
let done = event.as_ref().map_or(false, |event| {
|
||||||
|
event
|
||||||
|
.choices
|
||||||
|
.last()
|
||||||
|
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||||
|
});
|
||||||
|
if tx.unbounded_send(event).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
error: OpenAIError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIError {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||||
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {}",
|
||||||
|
response.error.message,
|
||||||
|
)),
|
||||||
|
|
||||||
|
_ => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CompletionProvider {
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
prompt: OpenAIRequest,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpenAICompletionProvider {
|
||||||
|
api_key: String,
|
||||||
|
executor: Arc<Background>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAICompletionProvider {
|
||||||
|
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
||||||
|
Self { api_key, executor }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for OpenAICompletionProvider {
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
prompt: OpenAIRequest,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||||
|
async move {
|
||||||
|
let response = request.await?;
|
||||||
|
let stream = response
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ path = "src/assistant.rs"
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
ai = { path = "../ai" }
|
||||||
client = { path = "../client" }
|
client = { path = "../client" }
|
||||||
collections = { path = "../collections"}
|
collections = { path = "../collections"}
|
||||||
editor = { path = "../editor" }
|
editor = { path = "../editor" }
|
||||||
|
|
|
@ -3,37 +3,20 @@ mod assistant_settings;
|
||||||
mod codegen;
|
mod codegen;
|
||||||
mod streaming_diff;
|
mod streaming_diff;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use ai::completion::Role;
|
||||||
|
use anyhow::Result;
|
||||||
pub use assistant_panel::AssistantPanel;
|
pub use assistant_panel::AssistantPanel;
|
||||||
use assistant_settings::OpenAIModel;
|
use assistant_settings::OpenAIModel;
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
use futures::StreamExt;
|
||||||
use gpui::{executor::Background, AppContext};
|
use gpui::AppContext;
|
||||||
use isahc::{http::StatusCode, Request, RequestExt};
|
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
|
||||||
cmp::Reverse,
|
|
||||||
ffi::OsStr,
|
|
||||||
fmt::{self, Display},
|
|
||||||
io,
|
|
||||||
path::PathBuf,
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
use util::paths::CONVERSATIONS_DIR;
|
use util::paths::CONVERSATIONS_DIR;
|
||||||
|
|
||||||
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
|
||||||
|
|
||||||
// Data types for chat completion requests
|
|
||||||
#[derive(Debug, Default, Serialize)]
|
|
||||||
pub struct OpenAIRequest {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<RequestMessage>,
|
|
||||||
stream: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
|
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
|
||||||
)]
|
)]
|
||||||
|
@ -116,175 +99,10 @@ impl SavedConversationMetadata {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
struct RequestMessage {
|
|
||||||
role: Role,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ResponseMessage {
|
|
||||||
role: Option<Role>,
|
|
||||||
content: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
enum Role {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
System,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Role {
|
|
||||||
pub fn cycle(&mut self) {
|
|
||||||
*self = match self {
|
|
||||||
Role::User => Role::Assistant,
|
|
||||||
Role::Assistant => Role::System,
|
|
||||||
Role::System => Role::User,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Role {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Role::User => write!(f, "User"),
|
|
||||||
Role::Assistant => write!(f, "Assistant"),
|
|
||||||
Role::System => write!(f, "System"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct OpenAIResponseStreamEvent {
|
|
||||||
pub id: Option<String>,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u32,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatChoiceDelta>,
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct ChatChoiceDelta {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ResponseMessage,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct OpenAIUsage {
|
|
||||||
prompt_tokens: u64,
|
|
||||||
completion_tokens: u64,
|
|
||||||
total_tokens: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct OpenAIChoice {
|
|
||||||
text: String,
|
|
||||||
index: u32,
|
|
||||||
logprobs: Option<serde_json::Value>,
|
|
||||||
finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
assistant_panel::init(cx);
|
assistant_panel::init(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_completion(
|
|
||||||
api_key: String,
|
|
||||||
executor: Arc<Background>,
|
|
||||||
mut request: OpenAIRequest,
|
|
||||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
|
||||||
request.stream = true;
|
|
||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
|
||||||
|
|
||||||
let json_data = serde_json::to_string(&request)?;
|
|
||||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.body(json_data)?
|
|
||||||
.send_async()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
if status == StatusCode::OK {
|
|
||||||
executor
|
|
||||||
.spawn(async move {
|
|
||||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
|
||||||
|
|
||||||
fn parse_line(
|
|
||||||
line: Result<String, io::Error>,
|
|
||||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
|
||||||
if let Some(data) = line?.strip_prefix("data: ") {
|
|
||||||
let event = serde_json::from_str(&data)?;
|
|
||||||
Ok(Some(event))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some(line) = lines.next().await {
|
|
||||||
if let Some(event) = parse_line(line).transpose() {
|
|
||||||
let done = event.as_ref().map_or(false, |event| {
|
|
||||||
event
|
|
||||||
.choices
|
|
||||||
.last()
|
|
||||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
|
||||||
});
|
|
||||||
if tx.unbounded_send(event).is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if done {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
|
|
||||||
Ok(rx)
|
|
||||||
} else {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIResponse {
|
|
||||||
error: OpenAIError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIError {
|
|
||||||
message: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
|
||||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
|
||||||
"Failed to connect to OpenAI API: {}",
|
|
||||||
response.error.message,
|
|
||||||
)),
|
|
||||||
|
|
||||||
_ => Err(anyhow!(
|
|
||||||
"Failed to connect to OpenAI API: {} {}",
|
|
||||||
response.status(),
|
|
||||||
body,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[ctor::ctor]
|
#[ctor::ctor]
|
||||||
fn init_logger() {
|
fn init_logger() {
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
|
||||||
codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
|
codegen::{self, Codegen, CodegenKind},
|
||||||
stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
|
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||||
Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
|
SavedMessage,
|
||||||
|
};
|
||||||
|
use ai::completion::{
|
||||||
|
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
|
|
|
@ -1,59 +1,14 @@
|
||||||
use crate::{
|
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||||
stream_completion,
|
use ai::completion::{CompletionProvider, OpenAIRequest};
|
||||||
streaming_diff::{Hunk, StreamingDiff},
|
|
||||||
OpenAIRequest,
|
|
||||||
};
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use editor::{
|
use editor::{
|
||||||
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
|
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
|
||||||
};
|
};
|
||||||
use futures::{
|
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||||
channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
|
use gpui::{Entity, ModelContext, ModelHandle, Task};
|
||||||
};
|
|
||||||
use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
|
|
||||||
use language::{Rope, TransactionId};
|
use language::{Rope, TransactionId};
|
||||||
use std::{cmp, future, ops::Range, sync::Arc};
|
use std::{cmp, future, ops::Range, sync::Arc};
|
||||||
|
|
||||||
pub trait CompletionProvider {
|
|
||||||
fn complete(
|
|
||||||
&self,
|
|
||||||
prompt: OpenAIRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct OpenAICompletionProvider {
|
|
||||||
api_key: String,
|
|
||||||
executor: Arc<Background>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAICompletionProvider {
|
|
||||||
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
|
||||||
Self { api_key, executor }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionProvider for OpenAICompletionProvider {
|
|
||||||
fn complete(
|
|
||||||
&self,
|
|
||||||
prompt: OpenAIRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
|
||||||
async move {
|
|
||||||
let response = request.await?;
|
|
||||||
let stream = response
|
|
||||||
.filter_map(|response| async move {
|
|
||||||
match response {
|
|
||||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed();
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum Event {
|
pub enum Event {
|
||||||
Finished,
|
Finished,
|
||||||
Undone,
|
Undone,
|
||||||
|
@ -397,13 +352,17 @@ fn strip_markdown_codeblock(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures::stream;
|
use futures::{
|
||||||
|
future::BoxFuture,
|
||||||
|
stream::{self, BoxStream},
|
||||||
|
};
|
||||||
use gpui::{executor::Deterministic, TestAppContext};
|
use gpui::{executor::Deterministic, TestAppContext};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
|
use smol::future::FutureExt;
|
||||||
|
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_transform_autoindent(
|
async fn test_transform_autoindent(
|
||||||
|
|
|
@ -5,9 +5,9 @@ pub mod only_instance;
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
pub mod test;
|
pub mod test;
|
||||||
|
|
||||||
use assistant::AssistantPanel;
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use assets::Assets;
|
use assets::Assets;
|
||||||
|
use assistant::AssistantPanel;
|
||||||
use breadcrumbs::Breadcrumbs;
|
use breadcrumbs::Breadcrumbs;
|
||||||
pub use client;
|
pub use client;
|
||||||
use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
|
use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
|
||||||
|
@ -2418,7 +2418,7 @@ mod tests {
|
||||||
pane::init(cx);
|
pane::init(cx);
|
||||||
project_panel::init((), cx);
|
project_panel::init((), cx);
|
||||||
terminal_view::init(cx);
|
terminal_view::init(cx);
|
||||||
ai::init(cx);
|
assistant::init(cx);
|
||||||
app_state
|
app_state
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue