From 0b57ab730332dbf0033d652b4b531b2898c88039 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:34:22 +0200
Subject: [PATCH 01/23] cleaned up truncate vs truncate start

---
 crates/ai/src/models.rs                 | 37 ++++++++++++++-----------
 crates/ai/src/templates/base.rs         | 33 ++++++++++++++--------
 crates/ai/src/templates/file_context.rs | 12 +++++---
 crates/ai/src/templates/generate.rs     |  6 +++-
 4 files changed, 56 insertions(+), 32 deletions(-)

diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index d0206cc41c..afb8496156 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -2,11 +2,20 @@ use anyhow::anyhow;
 use tiktoken_rs::CoreBPE;
 use util::ResultExt;
 
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
         if let Some(bpe) = &self.bpe {
             let tokens = bpe.encode_with_special_tokens(content);
             if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
             } else {
                 bpe.decode(tokens)
             }
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index bda1d6c30e..e5ac414bc1 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -125,6 +125,8 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+
     use super::*;
 
     #[test]
@@ -141,7 +143,11 @@ pub(crate) mod tests {
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::Start,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::Start,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
         fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
             anyhow::Ok(content.chars().collect::<Vec<char>>().len())
         }
-        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[..length]
+        fn truncate(
+            &self,
+            content: &str,
+            length: usize,
+            direction: TruncationDirection,
+        ) -> anyhow::Result<String> {
+            anyhow::Ok(match direction {
+                TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                     .into_iter()
                     .collect::<String>(),
-            )
-        }
-        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[length..]
+                TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                     .into_iter()
                     .collect::<String>(),
-            )
+            })
         }
         fn capacity(&self) -> anyhow::Result<usize> {
             anyhow::Ok(self.capacity)
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index 1afd61192e..1517134abb 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -3,6 +3,7 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
@@ -70,8 +71,9 @@ fn retrieve_context(
         };
 
         let truncated_start_window =
-            model.truncate_start(&start_window, start_goal_tokens)?;
-        let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+        let truncated_end_window =
+            model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
         writeln!(
             prompt,
             "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
     if let Some(max_token_count) = max_token_count {
         if model.count_tokens(&prompt)? > max_token_count {
             truncated = true;
-            prompt = model.truncate(&prompt, max_token_count)?;
+            prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
         }
     }
 }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args
+                .model
+                .truncate(&prompt, max_tokens, TruncationDirection::End)?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
index 1eeb197f93..c9541c6b44 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/templates/generate.rs
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;
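With this change every call site picks a truncation direction explicitly
instead of choosing between two differently named methods. A minimal sketch of
the new call shape (hypothetical helper, not part of this patch; assumes a
`LanguageModel` implementation is in scope):

    use ai::models::{LanguageModel, TruncationDirection};

    // Keep at most `max` tokens of a prompt, dropping tokens from the front
    // so that the most recent text survives.
    fn keep_tail(model: &dyn LanguageModel, prompt: &str, max: usize) -> anyhow::Result<String> {
        if model.count_tokens(prompt)? <= max {
            return Ok(prompt.to_string());
        }
        // TruncationDirection::Start removes tokens from the start of the text.
        model.truncate(prompt, max, TruncationDirection::Start)
    }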
From a62baf34f2e3bef619ba57e557cb30baa6356b29 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:46:49 +0200
Subject: [PATCH 02/23] rename templates to prompts in ai crate

---
 crates/ai/src/ai.rs                                        |  2 +-
 crates/ai/src/{templates => prompts}/base.rs               |  2 +-
 crates/ai/src/{templates => prompts}/file_context.rs       |  4 ++--
 crates/ai/src/{templates => prompts}/generate.rs           |  2 +-
 crates/ai/src/{templates => prompts}/mod.rs                |  0
 crates/ai/src/{templates => prompts}/preamble.rs           |  2 +-
 crates/ai/src/{templates => prompts}/repository_context.rs |  2 +-
 crates/assistant/src/assistant_panel.rs                    |  2 +-
 crates/assistant/src/prompts.rs                            | 10 +++++-----
 9 files changed, 13 insertions(+), 13 deletions(-)
 rename crates/ai/src/{templates => prompts}/base.rs (99%)
 rename crates/ai/src/{templates => prompts}/file_context.rs (98%)
 rename crates/ai/src/{templates => prompts}/generate.rs (97%)
 rename crates/ai/src/{templates => prompts}/mod.rs (100%)
 rename crates/ai/src/{templates => prompts}/preamble.rs (95%)
 rename crates/ai/src/{templates => prompts}/repository_context.rs (98%)

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index f168c15793..c0b78b74cf 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,4 +1,4 @@
 pub mod completion;
 pub mod embedding;
 pub mod models;
-pub mod templates;
+pub mod prompts;
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/prompts/base.rs
similarity index 99%
rename from crates/ai/src/templates/base.rs
rename to crates/ai/src/prompts/base.rs
index e5ac414bc1..f0ff597e63 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/prompts/base.rs
@@ -6,7 +6,7 @@ use language::BufferSnapshot;
 use util::ResultExt;
 
 use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
 
 pub(crate) enum PromptFileType {
     Text,
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/prompts/file_context.rs
similarity index 98%
rename from crates/ai/src/templates/file_context.rs
rename to crates/ai/src/prompts/file_context.rs
index 1517134abb..f108a62f6f 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/prompts/file_context.rs
@@ -4,8 +4,8 @@ use language::ToOffset;
 
 use crate::models::LanguageModel;
 use crate::models::TruncationDirection;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
 use std::fmt::Write;
 use std::ops::Range;
 use std::sync::Arc;
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/prompts/generate.rs
similarity index 97%
rename from crates/ai/src/templates/generate.rs
rename to crates/ai/src/prompts/generate.rs
index c9541c6b44..c7be620107 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/prompts/generate.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use anyhow::anyhow;
 use std::fmt::Write;
 
diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/prompts/mod.rs
similarity index 100%
rename from crates/ai/src/templates/mod.rs
rename to crates/ai/src/prompts/mod.rs
diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/prompts/preamble.rs
similarity index 95%
rename from crates/ai/src/templates/preamble.rs
rename to crates/ai/src/prompts/preamble.rs
index 9eabaaeb97..92e0edeb78 100644
--- a/crates/ai/src/templates/preamble.rs
+++ b/crates/ai/src/prompts/preamble.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
 pub struct EngineerPreamble {}
diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/prompts/repository_context.rs
similarity index 98%
rename from crates/ai/src/templates/repository_context.rs
rename to crates/ai/src/prompts/repository_context.rs
index a8e7f4b5af..c21b0f995c 100644
--- a/crates/ai/src/templates/repository_context.rs
+++ b/crates/ai/src/prompts/repository_context.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
 use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index ca8c54a285..64eff04b8d 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -9,7 +9,7 @@ use ai::{
     completion::{
         stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
     },
-    templates::repository_context::PromptCodeSnippet,
+    prompts::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index dffcbc2923..8fff232fdb 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,9 +1,9 @@
 use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;

From 3712794e561b7aa6068ee3de0fc411d5cb311566 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:47:28 +0200
Subject: [PATCH 03/23] move OpenAILanguageModel to providers folder

---
 crates/ai/src/ai.rs                      |  1 +
 crates/ai/src/models.rs                  | 55 -----------------------
 crates/ai/src/providers/mod.rs           |  1 +
 crates/ai/src/providers/open_ai/mod.rs   |  2 +
 crates/ai/src/providers/open_ai/model.rs | 56 ++++++++++++++++++++++++
 crates/assistant/src/prompts.rs          |  3 +-
 6 files changed, 62 insertions(+), 56 deletions(-)
 create mode 100644 crates/ai/src/providers/mod.rs
 create mode 100644 crates/ai/src/providers/open_ai/mod.rs
 create mode 100644 crates/ai/src/providers/open_ai/model.rs

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index c0b78b74cf..a3ae2fcf7f 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -2,3 +2,4 @@ pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
+pub mod providers;
diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index afb8496156..1db3d58c6f 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -1,7 +1,3 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
 pub enum TruncationDirection {
     Start,
     End,
@@ -18,54 +14,3 @@ pub trait LanguageModel {
     ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                match direction {
-                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
-                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
-                }
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}
diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs
new file mode 100644
index 0000000000..acd0f9d910
--- /dev/null
+++ b/crates/ai/src/providers/mod.rs
@@ -0,0 +1 @@
+pub mod open_ai;
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
new file mode 100644
index 0000000000..8d8489e187
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/mod.rs
@@ -0,0 +1,2 @@
+pub mod model;
+pub use model::OpenAILanguageModel;
diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs
new file mode 100644
index 0000000000..42523f3df4
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/model.rs
@@ -0,0 +1,56 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 8fff232fdb..25af023c40 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::models::LanguageModel;
 use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
 use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;
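With the model moved, downstream crates import the concrete OpenAI type from
`ai::providers::open_ai` while the `LanguageModel` trait stays in `ai::models`.
An illustrative use of the relocated type (hypothetical snippet; "gpt-4" stands
in for whatever model name a caller passes):

    use ai::models::{LanguageModel, TruncationDirection};
    use ai::providers::open_ai::OpenAILanguageModel;

    // Load a tokenizer-backed model description and use it for token math.
    let model = OpenAILanguageModel::load("gpt-4");
    let token_count = model.count_tokens("fn main() {}")?;
    let clipped = model.truncate("fn main() {}", 4, TruncationDirection::End)?;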
From 05ae978cb773978fcb16c81d14e8b4cd4907decd Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:57:13 +0200
Subject: [PATCH 04/23] move OpenAICompletionProvider to providers location

---
 crates/ai/src/completion.rs                   | 209 +-----------------
 crates/ai/src/providers/open_ai/completion.rs | 209 ++++++++++++++++++
 crates/ai/src/providers/open_ai/mod.rs        |   2 +
 crates/assistant/src/assistant.rs             |   2 +-
 crates/assistant/src/assistant_panel.rs       |  10 +-
 crates/assistant/src/codegen.rs               |   3 +-
 6 files changed, 222 insertions(+), 213 deletions(-)
 create mode 100644 crates/ai/src/providers/open_ai/completion.rs

diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index de6ce9da71..f45893898f 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -1,177 +1,7 @@
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
 
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
 
 pub trait CompletionProvider {
     fn complete(
         &self,
         prompt: OpenAIRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 }
-
-pub struct OpenAICompletionProvider {
-    api_key: String,
-    executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
-    }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-}
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
new file mode 100644
index 0000000000..bb6138eee3
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -0,0 +1,209 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui::executor::Background;
+use isahc::{http::StatusCode, Request, RequestExt};
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+
+use crate::completion::CompletionProvider;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
+            Role::Assistant => write!(f, "Assistant"),
+            Role::System => write!(f, "System"),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+    api_key: String,
+    executor: Arc<Background>,
+    mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    request.stream = true;
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = serde_json::to_string(&request)?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+pub struct OpenAICompletionProvider {
+    api_key: String,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
+        Self { api_key, executor }
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn complete(
+        &self,
+        prompt: OpenAIRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+}
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
index 8d8489e187..26f3068ca1 100644
--- a/crates/ai/src/providers/open_ai/mod.rs
+++ b/crates/ai/src/providers/open_ai/mod.rs
@@ -1,2 +1,4 @@
+pub mod completion;
 pub mod model;
+pub use completion::*;
 pub use model::OpenAILanguageModel;
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 6c9b14333e..91d61a19f9 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -4,7 +4,7 @@ mod codegen;
 mod prompts;
 mod streaming_diff;
 
-use ai::completion::Role;
+use ai::providers::open_ai::Role;
 use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
 use assistant_settings::OpenAIModel;
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 64eff04b8d..9b749e5091 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -5,12 +5,12 @@ use crate::{
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
-use ai::{
-    completion::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
-    },
-    prompts::repository_context::PromptCodeSnippet,
+
+use ai::providers::open_ai::{
+    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
 };
+
+use ai::prompts::repository_context::PromptCodeSnippet;
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index b6ef6b5cfa..66d2f60690 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -1,5 +1,6 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::CompletionProvider;
+use ai::providers::open_ai::OpenAIRequest;
 use anyhow::Result;
 use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
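The completion provider keeps the same shape after the move; only its path
changes. An illustrative consumer (hypothetical; assumes an `api_key` and a
background `executor` are already available, as they are in the assistant
panel):

    use ai::completion::CompletionProvider;
    use ai::providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage, Role};

    let provider = OpenAICompletionProvider::new(api_key, executor);
    let request = OpenAIRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Say hi".into(),
        }],
        stream: true,
        stop: vec![],
        temperature: 1.0,
    };
    // `complete` resolves to a stream of content deltas.
    let stream = provider.complete(request);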
From d813ae88458ed5e14899c5ebdd4437daa033ae6e Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 14:33:19 +0200
Subject: [PATCH 05/23] replace OpenAIRequest with more generalized
 Box<dyn CompletionRequest>

---
 crates/ai/src/completion.rs                   |  6 +++--
 crates/ai/src/providers/dummy.rs              | 13 ++++++++++
 crates/ai/src/providers/mod.rs                |  1 +
 crates/ai/src/providers/open_ai/completion.rs | 16 +++++++-----
 crates/assistant/src/assistant_panel.rs       | 20 +++++++++------
 crates/assistant/src/codegen.rs               | 25 +++++++++++++------
 6 files changed, 58 insertions(+), 23 deletions(-)
 create mode 100644 crates/ai/src/providers/dummy.rs

diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index f45893898f..ba89c869d2 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -1,11 +1,13 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
 
+pub trait CompletionRequest: Send + Sync {
+    fn data(&self) -> serde_json::Result<String>;
+}
+
 pub trait CompletionProvider {
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 }
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
new file mode 100644
index 0000000000..be42b13f2f
--- /dev/null
+++ b/crates/ai/src/providers/dummy.rs
@@ -0,0 +1,13 @@
+use crate::completion::CompletionRequest;
+use serde::Serialize;
+
+#[derive(Serialize)]
+pub struct DummyCompletionRequest {
+    pub name: String,
+}
+
+impl CompletionRequest for DummyCompletionRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs
index acd0f9d910..7a7092baf3 100644
--- a/crates/ai/src/providers/mod.rs
+++ b/crates/ai/src/providers/mod.rs
@@ -1 +1,2 @@
+pub mod dummy;
 pub mod open_ai;
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index bb6138eee3..95ed13c0dd 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -12,7 +12,7 @@ use std::{
     sync::Arc,
 };
 
-use crate::completion::CompletionProvider;
+use crate::completion::{CompletionProvider, CompletionRequest};
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -59,6 +59,12 @@ pub struct OpenAIRequest {
     pub temperature: f32,
 }
 
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ResponseMessage {
     pub role: Option<Role>,
     pub content: Option<String>,
 }
 
@@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
 pub async fn stream_completion(
     api_key: String,
     executor: Arc<Background>,
-    mut request: OpenAIRequest,
+    request: Box<dyn CompletionRequest>,
 ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 
-    let json_data = serde_json::to_string(&request)?;
+    let json_data = request.data()?;
     let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
@@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
 impl CompletionProvider for OpenAICompletionProvider {
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
         async move {
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 9b749e5091..ec16c8fd04 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -6,8 +6,11 @@ use crate::{
     SavedMessage,
 };
 
-use ai::providers::open_ai::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::CompletionRequest,
+    providers::open_ai::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
 };
 
 use ai::prompts::repository_context::PromptCodeSnippet;
@@ -745,13 +748,14 @@ impl AssistantPanel {
             content: prompt,
         });
 
-        let request = OpenAIRequest {
+        let request = Box::new(OpenAIRequest {
             model: model.full_name().into(),
             messages,
             stream: true,
             stop: vec!["|END|>".to_string()],
             temperature,
-        };
+        });
+
         codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
         anyhow::Ok(())
     })
@@ -1735,7 +1739,7 @@ impl Conversation {
             return Default::default();
         };
 
-        let request = OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
             model: self.model.full_name().to_string(),
             messages: self
                 .messages(cx)
             stream: true,
             stop: vec![],
             temperature: 1.0,
-        };
+        });
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
         let assistant_message = self
@@ -2025,13 +2029,13 @@ impl Conversation {
                     "Summarize the conversation into a short title without punctuation"
                         .into(),
             }));
-        let request = OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
             model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
             stop: vec![],
             temperature: 1.0,
-        };
+        });
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
         self.pending_summary = cx.spawn(|this, mut cx| {
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index 66d2f60690..e535eca144 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -1,6 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::CompletionProvider;
-use ai::providers::open_ai::OpenAIRequest;
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +95,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use ai::providers::dummy::DummyCompletionRequest;
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -381,7 +381,10 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             " let mut x = 0;\n",
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "t mut x = 0;\n",
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "let mut x = 0;\n",
     impl CompletionProvider for TestCompletionProvider {
         fn complete(
             &self,
-            _prompt: OpenAIRequest,
+            _prompt: Box<dyn CompletionRequest>,
         ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
             let (tx, rx) = mpsc::channel(1);
             *self.last_completion_tx.lock() = Some(tx);
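After this patch any serializable request can flow through a
`CompletionProvider` by implementing the one-method `CompletionRequest` trait.
A sketch of a custom request type (hypothetical; it mirrors the
`DummyCompletionRequest` used in the codegen tests):

    use ai::completion::{CompletionProvider, CompletionRequest};
    use serde::Serialize;

    #[derive(Serialize)]
    struct EchoRequest {
        prompt: String,
    }

    impl CompletionRequest for EchoRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    // Providers now accept any request boxed behind the trait object.
    let request: Box<dyn CompletionRequest> = Box::new(EchoRequest {
        prompt: "hello".into(),
    });
    let stream = provider.complete(request);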
From d1dec8314adb8be628912642efeffb572ef83b71 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 14:46:22 +0200
Subject: [PATCH 06/23] move OpenAIEmbeddings to OpenAIEmbeddingProvider in
 providers folder

---
 crates/ai/src/embedding.rs                   | 287 +------------------
 crates/ai/src/providers/dummy.rs             |  37 ++-
 crates/ai/src/providers/open_ai/embedding.rs | 252 ++++++++++++++++
 crates/ai/src/providers/open_ai/mod.rs       |   3 +
 crates/semantic_index/src/semantic_index.rs  |   5 +-
 .../src/semantic_index_tests.rs              |  19 +-
 crates/zed/examples/semantic_index_eval.rs   |   4 +-
 7 files changed, 308 insertions(+), 299 deletions(-)
 create mode 100644 crates/ai/src/providers/open_ai/embedding.rs

diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 4587ece0a2..05798c3f5d 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -1,30 +1,9 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::serde_json;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-
-lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use std::time::Instant;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -85,39 +64,6 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn is_authenticated(&self) -> bool;
@@ -127,235 +73,6 @@ pub trait EmbeddingProvider: Sync + Send {
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
 
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
index be42b13f2f..8061a2ca6b 100644
--- a/crates/ai/src/providers/dummy.rs
+++ b/crates/ai/src/providers/dummy.rs
@@ -1,4 +1,10 @@
-use crate::completion::CompletionRequest;
+use std::time::Instant;
+
+use crate::{
+    completion::CompletionRequest,
+    embedding::{Embedding, EmbeddingProvider},
+};
+use async_trait::async_trait;
 use serde::Serialize;
 
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
     pub name: String,
 }
 
 impl CompletionRequest for DummyCompletionRequest {
     fn data(&self) -> serde_json::Result<String> {
         serde_json::to_string(self)
     }
 }
+
+pub struct DummyEmbeddingProvider {}
+
+#[async_trait]
+impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        // 1024 is the OpenAI Embeddings size for ada models.
+        // the model we will likely be starting with.
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
+        return Ok(vec![dummy_vec; spans.len()]);
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        8190
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let truncated = span.chars().collect::<Vec<char>>()[..8190]
+            .iter()
+            .collect::<String>();
+        (truncated, 8190)
+    }
+}
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
new file mode 100644
index 0000000000..35398394dc
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -0,0 +1,252 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::serde_json;
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+
+use crate::embedding::{Embedding, EmbeddingProvider};
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddingProvider {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        OPENAI_API_KEY.as_ref().is_some()
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens.clone())
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        (output, tokens.len())
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
index 26f3068ca1..67cb2b5315 100644
--- a/crates/ai/src/providers/open_ai/mod.rs
+++ b/crates/ai/src/providers/open_ai/mod.rs
@@ -1,4 +1,7 @@
 pub mod completion;
+pub mod embedding;
 pub mod model;
 
 pub use completion::*;
+pub use embedding::*;
 pub use model::OpenAILanguageModel;
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index ecdba43643..926eb3045c 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 182010ca83..6842ce5c5d 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -4,7 +4,8 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult,
     SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::dummy::DummyEmbeddingProvider;
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -280,7 +281,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +467,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
diff --git a/crates/zed/examples/semantic_index_eval.rs b/crates/zed/examples/semantic_index_eval.rs
index 33d6b3689c..0bada47502 100644
--- a/crates/zed/examples/semantic_index_eval.rs
+++ b/crates/zed/examples/semantic_index_eval.rs
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -474,7 +474,7 @@ fn main() {
     let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         languages.clone(),
         cx.clone(),
     )
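With this split the `EmbeddingProvider` trait stays in `ai::embedding` and the
concrete implementations live under `ai::providers`. An illustrative call
(hypothetical; assumes an async context and ignores rate-limit handling):

    use ai::embedding::EmbeddingProvider;
    use ai::providers::dummy::DummyEmbeddingProvider;

    let provider = DummyEmbeddingProvider {};
    let spans = vec!["fn main() {}".to_string()];
    // `embed_batch` returns one `Embedding` per input span.
    let embeddings = provider.embed_batch(spans).await?;
    assert_eq!(embeddings.len(), 1);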
for truncation

---
 crates/ai/src/embedding.rs                   |  3 ++
 crates/ai/src/providers/dummy.rs             | 35 +++++++++++++++++++
 crates/ai/src/providers/open_ai/embedding.rs | 10 ++++++
 crates/ai/src/providers/open_ai/model.rs     |  1 +
 .../src/semantic_index_tests.rs              | 10 ++++--
 5 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 05798c3f5d..f792406c8b 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -5,6 +5,8 @@ use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 use std::time::Instant;
 
+use crate::models::LanguageModel;
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
 
@@ -66,6 +68,7 @@ impl Embedding {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn is_authenticated(&self) -> bool;
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
index 8061a2ca6b..9df5547da1 100644
--- a/crates/ai/src/providers/dummy.rs
+++ b/crates/ai/src/providers/dummy.rs
@@ -3,10 +3,42 @@ use std::time::Instant;
 use crate::{
     completion::CompletionRequest,
     embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
 };
 use async_trait::async_trait;
 use serde::Serialize;
 
+pub struct DummyLanguageModel {}
+
+impl LanguageModel for DummyLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(1000)
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: crate::models::TruncationDirection,
+    ) -> anyhow::Result<String> {
+        let truncated = match direction {
+            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
+                .iter()
+                .collect::<String>(),
+        };
+
+        anyhow::Ok(truncated)
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+    }
+}
+
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
     pub name: String,
@@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(DummyLanguageModel {})
+    }
     fn is_authenticated(&self) -> bool {
         true
     }
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 35398394dc..ed028177f6 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
 
 use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
 
 lazy_static! 
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + OpenAIEmbeddingProvider { + model, client, executor, rate_limit_count_rx, @@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } fn is_authenticated(&self) -> bool { OPENAI_API_KEY.as_ref().is_some() } diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs index 42523f3df4..6e306c80b9 100644 --- a/crates/ai/src/providers/open_ai/model.rs +++ b/crates/ai/src/providers/open_ai/model.rs @@ -4,6 +4,7 @@ use util::ResultExt; use crate::models::{LanguageModel, TruncationDirection}; +#[derive(Clone)] pub struct OpenAILanguageModel { name: String, bpe: Option, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 6842ce5c5d..43779f5b6c 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -4,8 +4,11 @@ use crate::{ semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; -use ai::embedding::{Embedding, EmbeddingProvider}; -use ai::providers::dummy::DummyEmbeddingProvider; +use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel}; +use ai::{ + embedding::{Embedding, EmbeddingProvider}, + models::LanguageModel, +}; use anyhow::Result; use async_trait::async_trait; use gpui::{executor::Deterministic, Task, TestAppContext}; @@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider { #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(DummyLanguageModel {}) + } fn is_authenticated(&self) -> bool { true } From 4e90e4599973b2016370b99a5406bb1a49ca21f4 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 23 Oct 2023 14:07:45 +0200 Subject: [PATCH 08/23] move embedding truncation to base model --- crates/ai/src/embedding.rs | 1 - crates/ai/src/providers/dummy.rs | 11 +++---- crates/ai/src/providers/open_ai/embedding.rs | 28 ++++++++-------- crates/semantic_index/src/parsing.rs | 33 ++++++++++++++++--- .../src/semantic_index_tests.rs | 9 ++--- 5 files changed, 48 insertions(+), 34 deletions(-) diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index f792406c8b..4e67f44cae 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send { fn is_authenticated(&self) -> bool; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; - fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs index 9df5547da1..7eef16111d 100644 --- a/crates/ai/src/providers/dummy.rs +++ b/crates/ai/src/providers/dummy.rs @@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel { length: usize, direction: crate::models::TruncationDirection, ) -> anyhow::Result { + if content.len() < length { + return 
anyhow::Ok(content.to_string()); + } + let truncated = match direction { TruncationDirection::End => content.chars().collect::>()[..length] .iter() @@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider { fn max_tokens_per_batch(&self) -> usize { 8190 } - - fn truncate(&self, span: &str) -> (String, usize) { - let truncated = span.chars().collect::>()[..8190] - .iter() - .collect::(); - (truncated, 8190) - } } diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index ed028177f6..3689cb36f4 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage { total_tokens: usize, } -const OPENAI_INPUT_LIMIT: usize = 8190; - impl OpenAIEmbeddingProvider { pub fn new(client: Arc, executor: Arc) -> Self { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); @@ -151,20 +149,20 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens.clone()) - .ok() - .unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; + // fn truncate(&self, span: &str) -> (String, usize) { + // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + // let output = if tokens.len() > OPENAI_INPUT_LIMIT { + // tokens.truncate(OPENAI_INPUT_LIMIT); + // OPENAI_BPE_TOKENIZER + // .decode(tokens.clone()) + // .ok() + // .unwrap_or_else(|| span.to_string()) + // } else { + // span.to_string() + // }; - (output, tokens.len()) - } + // (output, tokens.len()) + // } async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index f9b8bac9a4..cb15ca453b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,4 +1,7 @@ -use ai::embedding::{Embedding, EmbeddingProvider}; +use ai::{ + embedding::{Embedding, EmbeddingProvider}, + models::TruncationDirection, +}; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ @@ -108,7 +111,14 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = self.embedding_provider.truncate(&document_span); + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -131,7 +141,15 @@ impl CodeContextRetriever { ) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = self.embedding_provider.truncate(&document_span); + + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -222,8 +240,13 @@ impl 
CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &span.content); - let (document_content, token_count) = - self.embedding_provider.truncate(&document_content); + let model = self.embedding_provider.base_model(); + let document_content = model.truncate( + &document_content, + model.capacity()?, + TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_content)?; span.content = document_content; span.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 43779f5b6c..002dee33e3 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { fn is_authenticated(&self) -> bool { true } - fn truncate(&self, span: &str) -> (String, usize) { - (span.to_string(), 1) - } - fn max_tokens_per_batch(&self) -> usize { - 200 + 1000 } fn rate_limit_expiration(&self) -> Option { @@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } From 3447a9478c62476728f1e0131d708699dad2bcd1 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 26 Oct 2023 11:18:16 +0200 Subject: [PATCH 09/23] updated authentication for embedding provider --- crates/ai/Cargo.toml | 3 + crates/ai/src/ai.rs | 3 + crates/ai/src/auth.rs | 20 +++ crates/ai/src/embedding.rs | 8 +- crates/ai/src/prompts/base.rs | 41 +----- crates/ai/src/providers/dummy.rs | 85 ------------ crates/ai/src/providers/mod.rs | 1 - crates/ai/src/providers/open_ai/auth.rs | 33 +++++ crates/ai/src/providers/open_ai/embedding.rs | 46 ++----- crates/ai/src/providers/open_ai/mod.rs | 1 + crates/ai/src/test.rs | 123 ++++++++++++++++++ crates/assistant/src/codegen.rs | 14 +- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding_queue.rs | 16 +-- crates/semantic_index/src/semantic_index.rs | 52 +++++--- .../src/semantic_index_tests.rs | 101 +++----------- 16 files changed, 277 insertions(+), 271 deletions(-) create mode 100644 crates/ai/src/auth.rs delete mode 100644 crates/ai/src/providers/dummy.rs create mode 100644 crates/ai/src/providers/open_ai/auth.rs create mode 100644 crates/ai/src/test.rs diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index b24c4e5ece..fb49a4b515 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -8,6 +8,9 @@ publish = false path = "src/ai.rs" doctest = false +[features] +test-support = [] + [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index a3ae2fcf7f..dda22d2a1d 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,8 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod models; pub mod prompts; pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs new file mode 100644 index 0000000000..a3ce8aece1 --- /dev/null +++ b/crates/ai/src/auth.rs @@ -0,0 +1,20 @@ +use gpui::AppContext; + +#[derive(Clone)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +pub trait CredentialProvider: Send + Sync { + fn 
retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; +} + +#[derive(Clone)] +pub struct NullCredentialProvider; +impl CredentialProvider for NullCredentialProvider { + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } +} diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 8cfc901525..50f04232ab 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -7,6 +7,7 @@ use ordered_float::OrderedFloat; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; +use crate::auth::{CredentialProvider, ProviderCredential}; use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] @@ -71,11 +72,14 @@ impl Embedding { #[async_trait] pub trait EmbeddingProvider: Sync + Send { fn base_model(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> Option; + fn credential_provider(&self) -> Box; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + self.credential_provider().retrieve_credentials(cx) + } async fn embed_batch( &self, spans: Vec, - api_key: Option, + credential: ProviderCredential, ) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn rate_limit_expiration(&self) -> Option; diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs index f0ff597e63..a2106c7410 100644 --- a/crates/ai/src/prompts/base.rs +++ b/crates/ai/src/prompts/base.rs @@ -126,6 +126,7 @@ impl PromptChain { #[cfg(test)] pub(crate) mod tests { use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; use super::*; @@ -181,39 +182,7 @@ pub(crate) mod tests { } } - #[derive(Clone)] - struct DummyLanguageModel { - capacity: usize, - } - - impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: TruncationDirection, - ) -> anyhow::Result { - anyhow::Ok(match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .into_iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[length..] 
- .into_iter() - .collect::(), - }) - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(self.capacity) - } - } - - let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -249,7 +218,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts - let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -286,7 +255,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts let capacity = 20; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -322,7 +291,7 @@ pub(crate) mod tests { // Change Ordering of Prompts Based on Priority let capacity = 120; let reserved_tokens = 10; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs deleted file mode 100644 index 2ee26488bd..0000000000 --- a/crates/ai/src/providers/dummy.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::time::Instant; - -use crate::{ - completion::CompletionRequest, - embedding::{Embedding, EmbeddingProvider}, - models::{LanguageModel, TruncationDirection}, -}; -use async_trait::async_trait; -use gpui::AppContext; -use serde::Serialize; - -pub struct DummyLanguageModel {} - -impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(1000) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: crate::models::TruncationDirection, - ) -> anyhow::Result { - if content.len() < length { - return anyhow::Ok(content.to_string()); - } - - let truncated = match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[..length] - .iter() - .collect::(), - }; - - anyhow::Ok(truncated) - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } -} - -#[derive(Serialize)] -pub struct DummyCompletionRequest { - pub name: String, -} - -impl CompletionRequest for DummyCompletionRequest { - fn data(&self) -> serde_json::Result { - serde_json::to_string(self) - } -} - -pub struct DummyEmbeddingProvider {} - -#[async_trait] -impl EmbeddingProvider for DummyEmbeddingProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Dummy Credentials".to_string()) - } - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn rate_limit_expiration(&self) -> Option { - None - } - async fn embed_batch( - &self, - spans: Vec, - api_key: Option, - ) -> anyhow::Result> { - // 1024 is the OpenAI Embeddings size for ada models. - // the model we will likely be starting with. 
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); - return Ok(vec![dummy_vec; spans.len()]); - } - - fn max_tokens_per_batch(&self) -> usize { - 8190 - } -} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs index 7a7092baf3..acd0f9d910 100644 --- a/crates/ai/src/providers/mod.rs +++ b/crates/ai/src/providers/mod.rs @@ -1,2 +1 @@ -pub mod dummy; pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs new file mode 100644 index 0000000000..c817ffea00 --- /dev/null +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -0,0 +1,33 @@ +use std::env; + +use gpui::AppContext; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::providers::open_ai::OPENAI_API_URL; + +#[derive(Clone)] +pub struct OpenAICredentialProvider {} + +impl CredentialProvider for OpenAICredentialProvider { + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + }; + + if let Some(api_key) = api_key { + ProviderCredential::Credentials { api_key } + } else { + ProviderCredential::NoCredentials + } + } +} diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 805a906dda..1385b32b4d 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::{serde_json, AppContext}; +use gpui::serde_json; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; @@ -17,13 +17,13 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; -use util::ResultExt; +use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use super::OPENAI_API_URL; +use crate::providers::open_ai::auth::OpenAICredentialProvider; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); @@ -33,6 +33,7 @@ lazy_static! 
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider { OpenAIEmbeddingProvider { model, + credential_provider: OpenAICredentialProvider {}, client, executor, rate_limit_count_rx, @@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> Option { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - api_key - } fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } + + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - // fn truncate(&self, span: &str) -> (String, usize) { - // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let output = if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER - // .decode(tokens.clone()) - // .ok() - // .unwrap_or_else(|| span.to_string()) - // } else { - // span.to_string() - // }; - - // (output, tokens.len()) - // } async fn embed_batch( &self, spans: Vec, - api_key: Option, + _credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 67cb2b5315..49e29fbc8c 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod model; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs new file mode 100644 index 0000000000..d8805bad1a --- /dev/null +++ b/crates/ai/src/test.rs @@ -0,0 +1,123 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; + +use crate::{ + auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, + pub credential_provider: NullCredentialProvider, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + credential_provider: self.credential_provider.clone(), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + credential_provider: NullCredentialProvider {}, + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch( + &self, + spans: Vec, + _credential: ProviderCredential, + ) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e535eca144..e71b1ae2cb 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,6 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::providers::dummy::DummyCompletionRequest; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -345,9 +344,21 @@ mod tests { use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; + use serde::Serialize; use settings::SettingsStore; use smol::future::FutureExt; + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + impl CompletionRequest for DummyCompletionRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } + } + #[gpui::test(iterations = 10)] async fn test_transform_autoindent( cx: &mut TestAppContext, @@ -381,6 +392,7 @@ mod tests { cx, ) }); + let request = Box::new(DummyCompletionRequest { name: "test".to_string(), }); diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 1febb2af78..875440ef3f 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -42,6 +42,7 @@ sha1 = "0.10.5" ndarray = { version = "0.15.0" } [dev-dependencies] +ai = { path = "../ai", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } diff --git 
a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index d57d5c7bbe..9ca6d8a0d9 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::embedding::EmbeddingProvider; +use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - api_key: Option, + provider_credential: ProviderCredential, } #[derive(Clone)] @@ -54,7 +54,7 @@ impl EmbeddingQueue { pub fn new( embedding_provider: Arc, executor: Arc, - api_key: Option, + provider_credential: ProviderCredential, ) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { @@ -64,12 +64,12 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - api_key, + provider_credential, } } - pub fn set_api_key(&mut self, api_key: Option) { - self.api_key = api_key + pub fn set_credential(&mut self, credential: ProviderCredential) { + self.provider_credential = credential } pub fn push(&mut self, file: FileToEmbed) { @@ -118,7 +118,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +143,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, api_key).await { + match embedding_provider.embed_batch(spans, credential).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6863918d5d..5be3d6ccf5 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,6 +7,7 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; +use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -124,7 +125,7 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - api_key: Option, + provider_credential: ProviderCredential, embedding_queue: Arc>, } @@ -279,18 +280,27 @@ impl SemanticIndex { } } - pub fn authenticate(&mut self, cx: &AppContext) { - if self.api_key.is_none() { - self.api_key = self.embedding_provider.retrieve_credentials(cx); - - self.embedding_queue - .lock() - .set_api_key(self.api_key.clone()); + pub fn authenticate(&mut self, cx: &AppContext) -> bool { + let credential = self.provider_credential.clone(); + match credential { + ProviderCredential::NoCredentials => { + let credential = self.embedding_provider.retrieve_credentials(cx); + self.provider_credential = credential; + } + _ => {} } + + self.embedding_queue.lock().set_credential(credential); + + self.is_authenticated() } pub fn is_authenticated(&self) -> bool { - self.api_key.is_some() + let credential = &self.provider_credential; + match credential { + &ProviderCredential::Credentials { .. 
} => true, + _ => false, + } } pub fn enabled(cx: &AppContext) -> bool { @@ -340,7 +350,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -405,7 +415,7 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - api_key: None, + provider_credential: ProviderCredential::NoCredentials, embedding_queue } })) @@ -721,13 +731,14 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); + let query = embedding_provider - .embed_batch(vec![query], api_key) + .embed_batch(vec![query], credential) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -945,7 +956,7 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -964,7 +975,7 @@ impl SemanticIndex { &mut spans, embedding_provider.as_ref(), &db, - api_key.clone(), + credential.clone(), ) .await .log_err() @@ -1008,9 +1019,8 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task> { - if self.api_key.is_none() { - self.authenticate(cx); - if self.api_key.is_none() { + if !self.is_authenticated() { + if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } @@ -1193,7 +1203,7 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - api_key: Option, + credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1216,7 +1226,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key.clone()) + .embed_batch(mem::take(&mut batch), credential.clone()) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1228,7 +1238,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key) + .embed_batch(mem::take(&mut batch), credential) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 1c117c9ea2..7d5a4e22e8 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -4,14 +4,9 @@ use crate::{ semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; -use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel}; -use ai::{ - embedding::{Embedding, EmbeddingProvider}, - models::LanguageModel, -}; -use anyhow::Result; -use async_trait::async_trait; -use 
gpui::{executor::Deterministic, AppContext, Task, TestAppContext}; +use ai::test::FakeEmbeddingProvider; + +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; -use std::{ - path::Path, - sync::{ - atomic::{self, AtomicUsize}, - Arc, - }, - time::{Instant, SystemTime}, -}; +use std::{path::Path, sync::Arc, time::SystemTime}; use unindent::Unindent; use util::RandomCharIter; @@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None); + let mut queue = EmbeddingQueue::new( + embedding_provider.clone(), + cx.background(), + ai::auth::ProviderCredential::NoCredentials, + ); for file in &files { queue.push(file.clone()); } @@ -284,7 +276,7 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -470,7 +462,7 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let 
embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() { ); } -#[derive(Default)] -struct FakeEmbeddingProvider { - embedding_count: AtomicUsize, -} - -impl FakeEmbeddingProvider { - fn embedding_count(&self) -> usize { - self.embedding_count.load(atomic::Ordering::SeqCst) - } - - fn embed_sync(&self, span: &str) -> Embedding { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result.into() - } -} - -#[async_trait] -impl EmbeddingProvider for FakeEmbeddingProvider { - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Fake Credentials".to_string()) - } - fn max_tokens_per_batch(&self) -> usize { - 1000 - } - - fn rate_limit_expiration(&self) -> Option { - None - } - - async fn embed_batch( - &self, - spans: Vec, - _api_key: Option, - ) -> Result> { - self.embedding_count - .fetch_add(spans.len(), atomic::Ordering::SeqCst); - - anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) - } -} - fn js_lang() -> Arc { Arc::new( Language::new( From ca82ec8e8e2484099de3020233804671387ba9de Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 26 Oct 2023 14:05:55 +0200 Subject: [PATCH 10/23] fixed truncation error in fake language model --- crates/ai/src/auth.rs | 2 +- crates/ai/src/test.rs | 4 ++++ crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/semantic_index.rs | 5 ++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index a3ce8aece1..c188c30797 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -1,6 +1,6 @@ use gpui::AppContext; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum ProviderCredential { Credentials { api_key: String }, NoCredentials, diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index d8805bad1a..bc143e3c21 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -29,6 +29,10 @@ impl LanguageModel for FakeLanguageModel { length: usize, direction: TruncationDirection, ) -> anyhow::Result { + if length > self.count_tokens(content)? 
{
+            return anyhow::Ok(content.to_string());
+        }
+
         anyhow::Ok(match direction {
             TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                 .into_iter()
diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs
index 9ca6d8a0d9..299aa328b5 100644
--- a/crates/semantic_index/src/embedding_queue.rs
+++ b/crates/semantic_index/src/embedding_queue.rs
@@ -69,7 +69,7 @@ impl EmbeddingQueue {
     }
 
     pub fn set_credential(&mut self, credential: ProviderCredential) {
-        self.provider_credential = credential
+        self.provider_credential = credential;
     }
 
     pub fn push(&mut self, file: FileToEmbed) {
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 5be3d6ccf5..f420e0503b 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -291,7 +291,6 @@ impl SemanticIndex {
         }
 
         self.embedding_queue.lock().set_credential(credential);
-
         self.is_authenticated()
     }
 
@@ -299,6 +298,7 @@ impl SemanticIndex {
         let credential = &self.provider_credential;
         match credential {
             &ProviderCredential::Credentials { .. } => true,
+            &ProviderCredential::NotNeeded => true,
             _ => false,
         }
     }
@@ -1020,11 +1020,14 @@ impl SemanticIndex {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
         if !self.is_authenticated() {
+            println!("Authenticating");
             if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
 
+        println!("SHOULD NOW BE AUTHENTICATED");
+
         if !self.projects.contains_key(&project.downgrade()) {
             let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                 project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {

From 6c8bb4b05e62aab88d5487cecb215c2fe8863a49 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Fri, 27 Oct 2023 08:33:35 +0200
Subject: [PATCH 11/23] ensure OpenAIEmbeddingProvider is using the provider
 credentials

---
 crates/ai/src/providers/open_ai/embedding.rs | 11 ++++++-----
 crates/semantic_index/src/embedding_queue.rs |  2 +-
 crates/semantic_index/src/semantic_index.rs  | 17 ++++++-----------
 3 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 1385b32b4d..9806877660 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
     async fn embed_batch(
         &self,
         spans: Vec<String>,
-        _credential: ProviderCredential,
+        credential: ProviderCredential,
     ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        let api_key = match credential {
+            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
+            _ => Err(anyhow!("no api key provided")),
+        }?;
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
-                    api_key,
+                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
                 )
diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs
index 299aa328b5..6f792c78e2 100644
--- a/crates/semantic_index/src/embedding_queue.rs
+++ b/crates/semantic_index/src/embedding_queue.rs
@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: 
channel::Sender, finished_files_rx: channel::Receiver, - provider_credential: ProviderCredential, + pub provider_credential: ProviderCredential, } #[derive(Clone)] diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index f420e0503b..7fb5f749b4 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -281,15 +281,13 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let credential = self.provider_credential.clone(); - match credential { - ProviderCredential::NoCredentials => { - let credential = self.embedding_provider.retrieve_credentials(cx); - self.provider_credential = credential; - } - _ => {} - } + let existing_credential = self.provider_credential.clone(); + let credential = match existing_credential { + ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), + _ => existing_credential, + }; + self.provider_credential = credential.clone(); self.embedding_queue.lock().set_credential(credential); self.is_authenticated() } @@ -1020,14 +1018,11 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.is_authenticated() { - println!("Authenticating"); if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } - println!("SHOULD NOW BE AUTHENTICATED"); - if !self.projects.contains_key(&project.downgrade()) { let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { From ec9d79b6fec4e90f34367bd3a855ef11c58f75fd Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 08:51:30 +0200 Subject: [PATCH 12/23] add concept of LanguageModel to CompletionProvider --- crates/ai/src/completion.rs | 3 +++ crates/ai/src/providers/open_ai/completion.rs | 21 ++++++++++++++++--- crates/ai/src/providers/open_ai/embedding.rs | 1 - crates/assistant/src/assistant_panel.rs | 1 + crates/assistant/src/codegen.rs | 5 +++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index ba89c869d2..da9ebd5a1d 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,11 +1,14 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; +use crate::models::LanguageModel; + pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; } pub trait CompletionProvider { + fn base_model(&self) -> Box; fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 95ed13c0dd..20f72c0ff7 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -12,7 +12,12 @@ use std::{ sync::Arc, }; -use crate::completion::{CompletionProvider, CompletionRequest}; +use crate::{ + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use super::OpenAILanguageModel; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -180,17 +185,27 @@ pub async fn stream_completion( } pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, api_key: String, executor: Arc, } impl OpenAICompletionProvider { - pub fn new(api_key: String, executor: Arc) -> Self { - Self { api_key, executor } + pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + Self { + 
model, + api_key, + executor, + } } } impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 9806877660..64f568da1a 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel; use crate::providers::open_ai::auth::OpenAICredentialProvider; lazy_static! { - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ec16c8fd04..c899465ed2 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -328,6 +328,7 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( + "gpt-4", api_key, cx.background().clone(), )); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e71b1ae2cb..33adb2e570 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,6 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; + use ai::{models::LanguageModel, test::FakeLanguageModel}; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -638,6 +639,10 @@ mod tests { } impl CompletionProvider for TestCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } fn complete( &self, _prompt: Box, From 7af77b1cf95da45314092aa35f7bcc04fa4fd3bc Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 12:26:01 +0200 Subject: [PATCH 13/23] moved TestCompletionProvider into ai --- crates/ai/src/test.rs | 39 +++++++++++++++++++++++++++++++++ crates/assistant/Cargo.toml | 1 + crates/assistant/src/codegen.rs | 38 +------------------------------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index bc143e3c21..2c78027b62 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -4,9 +4,12 @@ use std::{ }; use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use parking_lot::Mutex; use crate::{ auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, }; @@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider { anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } + +pub struct TestCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl TestCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +impl CompletionProvider for TestCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn 
complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 9cfdd3301a..6b0ce659e3 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -44,6 +44,7 @@ tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +ai = { path = "../ai", features = ["test-support"]} ctor.workspace = true env_logger.workspace = true diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 33adb2e570..3516fc3708 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::{models::LanguageModel, test::FakeLanguageModel}; + use ai::test::TestCompletionProvider; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -617,42 +617,6 @@ mod tests { } } - struct TestCompletionProvider { - last_completion_tx: Mutex>>, - } - - impl TestCompletionProvider { - fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } - - fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } - - fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } - } - - impl CompletionProvider for TestCompletionProvider { - fn base_model(&self) -> Box { - let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); - model - } - fn complete( - &self, - _prompt: Box, - ) -> BoxFuture<'static, Result>>> { - let (tx, rx) = mpsc::channel(1); - *self.last_completion_tx.lock() = Some(tx); - async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() - } - } - fn rust_lang() -> Language { Language::new( LanguageConfig { From 558f54c424a1b0f7ccaa317af62b50fd5c467fc0 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sat, 28 Oct 2023 16:35:43 -0400 Subject: [PATCH 14/23] added credential provider to completion provider --- crates/ai/src/completion.rs | 10 +++++++++- crates/ai/src/providers/open_ai/completion.rs | 10 +++++++++- crates/ai/src/providers/open_ai/embedding.rs | 1 - crates/ai/src/test.rs | 3 +++ crates/assistant/src/codegen.rs | 7 +------ 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index da9ebd5a1d..5b9bad4870 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,7 +1,11 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; +use gpui::AppContext; -use crate::models::LanguageModel; +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + models::LanguageModel, +}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; @@ -9,6 +13,10 @@ pub trait CompletionRequest: Send + Sync { pub trait CompletionProvider { fn base_model(&self) -> Box; + fn credential_provider(&self) -> Box; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + self.credential_provider().retrieve_credentials(cx) + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 20f72c0ff7..9c9d205ff7 100644 --- 
a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,11 +13,12 @@ use std::{ }; use crate::{ + auth::CredentialProvider, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; -use super::OpenAILanguageModel; +use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -186,6 +187,7 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, api_key: String, executor: Arc, } @@ -193,8 +195,10 @@ pub struct OpenAICompletionProvider { impl OpenAICompletionProvider { pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); + let credential_provider = OpenAICredentialProvider {}; Self { model, + credential_provider, api_key, executor, } @@ -206,6 +210,10 @@ impl CompletionProvider for OpenAICompletionProvider { let model: Box = Box::new(self.model.clone()); model } + fn credential_provider(&self) -> Box { + let provider: Box = Box::new(self.credential_provider.clone()); + provider + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 64f568da1a..dafc94580d 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -11,7 +11,6 @@ use parking_lot::Mutex; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; -use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index 2c78027b62..b8f99af400 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -155,6 +155,9 @@ impl CompletionProvider for TestCompletionProvider { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } + fn credential_provider(&self) -> Box { + Box::new(NullCredentialProvider {}) + } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 3516fc3708..7f4c95f655 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -336,18 +336,13 @@ fn strip_markdown_codeblock( mod tests { use super::*; use ai::test::TestCompletionProvider; - use futures::{ - future::BoxFuture, - stream::{self, BoxStream}, - }; + use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; - use parking_lot::Mutex; use rand::prelude::*; use serde::Serialize; use settings::SettingsStore; - use smol::future::FutureExt; #[derive(Serialize)] pub struct DummyCompletionRequest { From 1e8b23d8fb9cf231de56ed25ebd56ea04190fc55 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sat, 28 Oct 2023 18:16:45 -0400 Subject: [PATCH 15/23] replace api_key with ProviderCredential throughout the AssistantPanel --- crates/ai/src/auth.rs | 4 + crates/ai/src/completion.rs | 6 + crates/ai/src/providers/open_ai/auth.rs | 13 + crates/ai/src/providers/open_ai/completion.rs | 24 +- crates/assistant/src/assistant_panel.rs | 282 +++++++++++------- 5 files changed, 208 insertions(+), 121 deletions(-) diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index c188c30797..cb3f2beabb 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -9,6 +9,8 @@ pub enum 
ProviderCredential { pub trait CredentialProvider: Send + Sync { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); + fn delete_credentials(&self, cx: &AppContext); } #[derive(Clone)] @@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider { fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { ProviderCredential::NotNeeded } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} + fn delete_credentials(&self, cx: &AppContext) {} } diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 5b9bad4870..6a2806a5cb 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -17,6 +17,12 @@ pub trait CompletionProvider { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { self.credential_provider().retrieve_credentials(cx) } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + self.credential_provider().save_credentials(cx, credential); + } + fn delete_credentials(&self, cx: &AppContext) { + self.credential_provider().delete_credentials(cx); + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs index c817ffea00..7cb51ab449 100644 --- a/crates/ai/src/providers/open_ai/auth.rs +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider { ProviderCredential::NoCredentials } } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + } } diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 9c9d205ff7..febe491123 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,7 +13,7 @@ use std::{ }; use crate::{ - auth::CredentialProvider, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; @@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent { } pub async fn stream_completion( - api_key: String, + credential: ProviderCredential, executor: Arc, request: Box, ) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + let (tx, rx) = futures::channel::mpsc::unbounded::>(); let json_data = request.data()?; @@ -188,18 +195,22 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, credential_provider: OpenAICredentialProvider, - api_key: String, + credential: ProviderCredential, executor: Arc, } impl OpenAICompletionProvider { - pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { + pub fn new( + model_name: &str, + credential: ProviderCredential, + executor: Arc, + ) -> Self { let model = OpenAILanguageModel::load(model_name); let credential_provider = OpenAICredentialProvider {}; Self { model, credential_provider, - api_key, + credential, executor, } } @@ -218,7 +229,8 @@ impl CompletionProvider for 
OpenAICompletionProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + let credential = self.credential.clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; let stream = response diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c899465ed2..f9187b8785 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -7,7 +7,8 @@ use crate::{ }; use ai::{ - completion::CompletionRequest, + auth::ProviderCredential, + completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, }, @@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(ConversationEditor::copy); cx.add_action(ConversationEditor::split); cx.capture_action(ConversationEditor::cycle_message_role); - cx.add_action(AssistantPanel::save_api_key); - cx.add_action(AssistantPanel::reset_api_key); + cx.add_action(AssistantPanel::save_credentials); + cx.add_action(AssistantPanel::reset_credentials); cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::select_next_match); @@ -143,7 +144,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - api_key: Rc>>, + credential: Rc>, + completion_provider: Box, api_key_editor: Option>, has_read_credentials: bool, languages: Arc, @@ -205,6 +207,12 @@ impl AssistantPanel { }); let semantic_index = SemanticIndex::global(cx); + // Defaulting currently to GPT4, allow for this to be set via config. 
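+        // The provider starts with `ProviderCredential::NoCredentials`; real
+        // credentials are filled in later by `load_credentials`, which reads
+        // the OPENAI_API_KEY environment variable or the platform keychain.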
+ let completion_provider = Box::new(OpenAICompletionProvider::new( + "gpt-4", + ProviderCredential::NoCredentials, + cx.background().clone(), + )); let mut this = Self { workspace: workspace_handle, @@ -216,7 +224,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - api_key: Rc::new(RefCell::new(None)), + credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), + completion_provider, api_key_editor: None, has_read_credentials: false, languages: workspace.app_state().languages.clone(), @@ -257,10 +266,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this - .update(cx, |assistant, cx| assistant.load_api_key(cx)) - .is_some() - { + if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { this } else { workspace.focus_panel::(cx); @@ -292,12 +298,7 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - + let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -329,7 +330,7 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - api_key, + credential, cx.background().clone(), )); @@ -816,7 +817,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.api_key.clone(), + self.credential.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -875,17 +876,20 @@ impl AssistantPanel { } } - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { if let Some(api_key) = self .api_key_editor .as_ref() .map(|editor| editor.read(cx).text(cx)) { if !api_key.is_empty() { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - *self.api_key.borrow_mut() = Some(api_key); + let credential = ProviderCredential::Credentials { + api_key: api_key.clone(), + }; + self.completion_provider + .save_credentials(cx, credential.clone()); + *self.credential.borrow_mut() = credential; + self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -895,9 +899,9 @@ impl AssistantPanel { } } - fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - self.api_key.take(); + fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + self.completion_provider.delete_credentials(cx); + *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1156,13 +1160,19 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let api_key = self.api_key.clone(); + let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + Conversation::deserialize( + saved_conversation, + path.clone(), + credential, + languages, 
+ cx, + ) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1186,30 +1196,39 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); + fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { + let credential = self.load_credentials(cx); + match credential { + ProviderCredential::Credentials { .. } => true, + ProviderCredential::NotNeeded => true, + ProviderCredential::NoCredentials => false, + } + } + + fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { + let existing_credential = self.credential.clone(); + let existing_credential = existing_credential.borrow().clone(); + match existing_credential { + ProviderCredential::NoCredentials => { + if !self.has_read_credentials { + self.has_read_credentials = true; + let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); + + match retrieved_credentials { + ProviderCredential::NoCredentials {} => { + self.api_key_editor = Some(build_api_key_editor(cx)); + cx.notify(); + } + _ => { + *self.credential.borrow_mut() = retrieved_credentials; + } + } + } } + _ => {} } - self.api_key.borrow().clone() + self.credential.borrow().clone() } } @@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - self.load_api_key(cx); + self.load_credentials(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1459,7 +1478,7 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, + credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, @@ -1471,7 +1490,8 @@ impl Entity for Conversation { impl Conversation { fn new( - api_key: Rc>>, + credential: Rc>, + language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1512,7 +1532,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - api_key, + credential, buffer, }; let message = MessageAnchor { @@ -1559,7 +1579,7 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - api_key: Rc>>, + credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1614,7 +1634,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - api_key, + credential, buffer, }; this.count_remaining_tokens(cx); @@ -1736,9 +1756,13 @@ impl Conversation { } if should_assist { - let Some(api_key) = self.api_key.borrow().clone() else { - return Default::default(); - }; + let credential = self.credential.borrow().clone(); + match credential { + ProviderCredential::NoCredentials => { + return Default::default(); + } + _ => {} + } let request: Box = 
Box::new(OpenAIRequest { model: self.model.full_name().to_string(), @@ -1752,7 +1776,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(api_key, cx.background().clone(), request); + let stream = stream_completion(credential, cx.background().clone(), request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -2018,57 +2042,62 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let api_key = self.api_key.borrow().clone(); - if let Some(api_key) = api_key { - let messages = self - .messages(cx) - .take(2) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })); - let request: Box = Box::new(OpenAIRequest { - model: self.model.full_name().to_string(), - messages: messages.collect(), - stream: true, - stop: vec![], - temperature: 1.0, - }); + let credential = self.credential.borrow().clone(); - let stream = stream_completion(api_key, cx.background().clone(), request); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } - } - - this.update(&mut cx, |this, cx| { - if let Some(summary) = this.summary.as_mut() { - summary.done = true; - cx.emit(ConversationEvent::SummaryChanged); - } - }); - - anyhow::Ok(()) - } - .log_err() - }); + match credential { + ProviderCredential::NoCredentials => { + return; + } + _ => {} } + + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + })); + let request: Box = Box::new(OpenAIRequest { + model: self.model.full_name().to_string(), + messages: messages.collect(), + stream: true, + stop: vec![], + temperature: 1.0, + }); + + let stream = stream_completion(credential, cx.background().clone(), request); + self.pending_summary = cx.spawn(|this, mut cx| { + async move { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + let text = choice.delta.content.unwrap_or_default(); + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); + } + } + + this.update(&mut cx, |this, cx| { + if let Some(summary) = this.summary.as_mut() { + summary.done = true; + cx.emit(ConversationEvent::SummaryChanged); + } + }); + + anyhow::Ok(()) + } + .log_err() + }); } } @@ -2229,13 +2258,13 @@ struct ConversationEditor { impl ConversationEditor { fn new( - api_key: Rc>>, + credential: Rc>, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); + let conversation = cx.add_model(|cx| 
Conversation::new(credential, language_registry, cx)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3431,7 +3460,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3559,7 +3594,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3655,7 +3696,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3737,8 +3784,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = - cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry.clone(), + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3775,7 +3827,7 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Default::default(), + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, ) From 34747bbbbc9190457eab894f2b86bdae07385312 Mon Sep 17 00:00:00 2001 From: "Joseph T. Lyons" Date: Sun, 29 Oct 2023 13:47:02 -0500 Subject: [PATCH 16/23] Do not call `scroll_to` twice --- crates/editor/src/editor.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index bfb87afff2..701a6882a0 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -967,7 +967,6 @@ impl CompletionsMenu { self.selected_item -= 1; } else { self.selected_item = self.matches.len() - 1; - self.list.scroll_to(ScrollTarget::Show(self.selected_item)); } self.list.scroll_to(ScrollTarget::Show(self.selected_item)); self.attempt_resolve_selected_completion_documentation(project, cx); @@ -1538,7 +1537,6 @@ impl CodeActionsMenu { self.selected_item -= 1; } else { self.selected_item = self.actions.len() - 1; - self.list.scroll_to(ScrollTarget::Show(self.selected_item)); } self.list.scroll_to(ScrollTarget::Show(self.selected_item)); cx.notify(); From dd89b2e6d497d0d4012f880aed8c8ea15177df6e Mon Sep 17 00:00:00 2001 From: "Joseph T. 
Lyons" Date: Sun, 29 Oct 2023 13:54:32 -0500 Subject: [PATCH 17/23] Pull duplicate call out of `if`-`else` block --- crates/editor/src/editor.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 701a6882a0..4e449bb7f7 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1545,11 +1545,10 @@ impl CodeActionsMenu { fn select_next(&mut self, cx: &mut ViewContext) { if self.selected_item + 1 < self.actions.len() { self.selected_item += 1; - self.list.scroll_to(ScrollTarget::Show(self.selected_item)); } else { self.selected_item = 0; - self.list.scroll_to(ScrollTarget::Show(self.selected_item)); } + self.list.scroll_to(ScrollTarget::Show(self.selected_item)); cx.notify(); } From 96bbb5cdea41d537324294bf025cfc6cca7ea51e Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Mon, 30 Oct 2023 11:14:00 +0200 Subject: [PATCH 18/23] Properly log prettier paths --- crates/prettier/src/prettier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 79fef40908..53e3101a3b 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -93,7 +93,7 @@ impl Prettier { ) })?; (worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root_data:?}") + panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") }), None) } else { let full_starting_path = worktree_root.join(&starting_path.starting_path); @@ -106,7 +106,7 @@ impl Prettier { })?; ( worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root_data:?}") + panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") }), start_path_data, ) From 249bec3cac269909d96226e5732ea36ce8b3569d Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Mon, 30 Oct 2023 12:13:34 +0200 Subject: [PATCH 19/23] Do not panic on prettier search --- crates/prettier/src/prettier.rs | 120 +++++++++++++++----------------- 1 file changed, 55 insertions(+), 65 deletions(-) diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 53e3101a3b..6784dba7dc 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -81,77 +81,67 @@ impl Prettier { if worktree_root != starting_path.worktree_root_path.as_ref() { vec![worktree_root] } else { - let (worktree_root_metadata, start_path_metadata) = if starting_path - .starting_path - .as_ref() - == Path::new("") - { - let worktree_root_data = - fs.metadata(&worktree_root).await.with_context(|| { - format!( - "FS metadata fetch for worktree root path {worktree_root:?}", - ) - })?; - (worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") - }), None) + let worktree_root_metadata = fs + .metadata(&worktree_root) + .await + .with_context(|| { + format!("FS metadata fetch for worktree root path {worktree_root:?}",) + })? 
+ .with_context(|| { + format!("empty FS metadata for worktree root at {worktree_root:?}") + })?; + if starting_path.starting_path.as_ref() == Path::new("") { + anyhow::ensure!( + !worktree_root_metadata.is_dir, + "For empty start path, worktree root should not be a directory {starting_path:?}" + ); + anyhow::ensure!( + !worktree_root_metadata.is_symlink, + "For empty start path, worktree root should not be a symlink {starting_path:?}" + ); + worktree_root + .parent() + .map(|path| vec![path.to_path_buf()]) + .unwrap_or_default() } else { let full_starting_path = worktree_root.join(&starting_path.starting_path); - let (worktree_root_data, start_path_data) = futures::try_join!( - fs.metadata(&worktree_root), - fs.metadata(&full_starting_path), - ) - .with_context(|| { - format!("FS metadata fetch for starting path {full_starting_path:?}",) - })?; - ( - worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") - }), - start_path_data, - ) - }; + let start_path_metadata = fs + .metadata(&full_starting_path) + .await + .with_context(|| { + format!( + "FS metadata fetch for starting path {full_starting_path:?}" + ) + })? + .with_context(|| { + format!( + "empty FS metadata for starting path {full_starting_path:?}" + ) + })?; - match start_path_metadata { - Some(start_path_metadata) => { - anyhow::ensure!(worktree_root_metadata.is_dir, - "For non-empty start path, worktree root {starting_path:?} should be a directory"); - anyhow::ensure!( - !start_path_metadata.is_dir, - "For non-empty start path, it should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !start_path_metadata.is_symlink, - "For non-empty start path, it should not be a symlink {starting_path:?}" - ); + anyhow::ensure!(worktree_root_metadata.is_dir, + "For non-empty start path, worktree root {starting_path:?} should be a directory"); + anyhow::ensure!( + !start_path_metadata.is_dir, + "For non-empty start path, it should not be a directory {starting_path:?}" + ); + anyhow::ensure!( + !start_path_metadata.is_symlink, + "For non-empty start path, it should not be a symlink {starting_path:?}" + ); - let file_to_format = starting_path.starting_path.as_ref(); - let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); - let mut current_path = worktree_root; - for path_component in file_to_format.components().into_iter() { - current_path = current_path.join(path_component); - paths_to_check.push_front(current_path.clone()); - if path_component.as_os_str().to_string_lossy() == "node_modules" { - break; - } + let file_to_format = starting_path.starting_path.as_ref(); + let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); + let mut current_path = worktree_root; + for path_component in file_to_format.components().into_iter() { + current_path = current_path.join(path_component); + paths_to_check.push_front(current_path.clone()); + if path_component.as_os_str().to_string_lossy() == "node_modules" { + break; } - paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it - Vec::from(paths_to_check) - } - None => { - anyhow::ensure!( - !worktree_root_metadata.is_dir, - "For empty start path, worktree root should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !worktree_root_metadata.is_symlink, - "For empty start path, worktree root should not be a symlink {starting_path:?}" - ); - worktree_root - .parent() - .map(|path| vec![path.to_path_buf()]) - .unwrap_or_default() } + 
paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it
+                        Vec::from(paths_to_check)
                     }
                 }
             }

From b46a4b56808f7c3521250bef6ee9e4f4389b6973 Mon Sep 17 00:00:00 2001
From: Kirill Bulatov
Date: Mon, 30 Oct 2023 12:07:11 +0200
Subject: [PATCH 20/23] Be more lenient when searching for prettier instance

Do not check FS for existence (we'll error when we start running prettier),
simplify the code for looking it up

---
 crates/prettier/src/prettier.rs | 62 ++++++---------------------------
 1 file changed, 10 insertions(+), 52 deletions(-)

diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs
index 6784dba7dc..7517b4ee43 100644
--- a/crates/prettier/src/prettier.rs
+++ b/crates/prettier/src/prettier.rs
@@ -67,80 +67,38 @@ impl Prettier {
         starting_path: Option<LocateStart>,
         fs: Arc<dyn Fs>,
     ) -> anyhow::Result<PathBuf> {
+        fn is_node_modules(path_component: &std::path::Component<'_>) -> bool {
+            path_component.as_os_str().to_string_lossy() == "node_modules"
+        }
+
         let paths_to_check = match starting_path.as_ref() {
             Some(starting_path) => {
                 let worktree_root = starting_path
                     .worktree_root_path
                     .components()
                     .into_iter()
-                    .take_while(|path_component| {
-                        path_component.as_os_str().to_string_lossy() != "node_modules"
-                    })
+                    .take_while(|path_component| !is_node_modules(path_component))
                     .collect::<PathBuf>();
-
                 if worktree_root != starting_path.worktree_root_path.as_ref() {
                     vec![worktree_root]
                 } else {
-                    let worktree_root_metadata = fs
-                        .metadata(&worktree_root)
-                        .await
-                        .with_context(|| {
-                            format!("FS metadata fetch for worktree root path {worktree_root:?}",)
-                        })?
-                        .with_context(|| {
-                            format!("empty FS metadata for worktree root at {worktree_root:?}")
-                        })?;
                     if starting_path.starting_path.as_ref() == Path::new("") {
-                        anyhow::ensure!(
-                            !worktree_root_metadata.is_dir,
-                            "For empty start path, worktree root should not be a directory {starting_path:?}"
-                        );
-                        anyhow::ensure!(
-                            !worktree_root_metadata.is_symlink,
-                            "For empty start path, worktree root should not be a symlink {starting_path:?}"
-                        );
                         worktree_root
                             .parent()
                             .map(|path| vec![path.to_path_buf()])
                             .unwrap_or_default()
                     } else {
-                        let full_starting_path = worktree_root.join(&starting_path.starting_path);
-                        let start_path_metadata = fs
-                            .metadata(&full_starting_path)
-                            .await
-                            .with_context(|| {
-                                format!(
-                                    "FS metadata fetch for starting path {full_starting_path:?}"
-                                )
-                            })?
- .with_context(|| { - format!( - "empty FS metadata for starting path {full_starting_path:?}" - ) - })?; - - anyhow::ensure!(worktree_root_metadata.is_dir, - "For non-empty start path, worktree root {starting_path:?} should be a directory"); - anyhow::ensure!( - !start_path_metadata.is_dir, - "For non-empty start path, it should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !start_path_metadata.is_symlink, - "For non-empty start path, it should not be a symlink {starting_path:?}" - ); - let file_to_format = starting_path.starting_path.as_ref(); - let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); + let mut paths_to_check = VecDeque::new(); let mut current_path = worktree_root; for path_component in file_to_format.components().into_iter() { - current_path = current_path.join(path_component); - paths_to_check.push_front(current_path.clone()); - if path_component.as_os_str().to_string_lossy() == "node_modules" { + let new_path = current_path.join(path_component); + let old_path = std::mem::replace(&mut current_path, new_path); + paths_to_check.push_front(old_path); + if is_node_modules(&path_component) { break; } } - paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it Vec::from(paths_to_check) } } From a2c3971ad6202bcec51dc0f36ef13497e94d1597 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 10:02:27 -0400 Subject: [PATCH 21/23] moved authentication for the semantic index into the EmbeddingProvider --- crates/ai/src/auth.rs | 11 +-- crates/ai/src/completion.rs | 18 +--- crates/ai/src/embedding.rs | 15 +--- crates/ai/src/providers/open_ai/auth.rs | 46 ---------- crates/ai/src/providers/open_ai/completion.rs | 78 ++++++++++++---- crates/ai/src/providers/open_ai/embedding.rs | 88 ++++++++++++++----- crates/ai/src/providers/open_ai/mod.rs | 3 +- crates/ai/src/providers/open_ai/new.rs | 11 +++ crates/ai/src/test.rs | 48 +++++----- crates/assistant/src/assistant_panel.rs | 7 +- crates/assistant/src/codegen.rs | 8 +- crates/semantic_index/src/embedding_queue.rs | 17 +--- crates/semantic_index/src/semantic_index.rs | 50 ++++------- .../src/semantic_index_tests.rs | 6 +- 14 files changed, 200 insertions(+), 206 deletions(-) delete mode 100644 crates/ai/src/providers/open_ai/auth.rs create mode 100644 crates/ai/src/providers/open_ai/new.rs diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index cb3f2beabb..c6256df216 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -8,17 +8,8 @@ pub enum ProviderCredential { } pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); fn delete_credentials(&self, cx: &AppContext); } - -#[derive(Clone)] -pub struct NullCredentialProvider; -impl CredentialProvider for NullCredentialProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { - ProviderCredential::NotNeeded - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} - fn delete_credentials(&self, cx: &AppContext) {} -} diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 6a2806a5cb..7fdc49e918 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,28 +1,14 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; -use gpui::AppContext; -use crate::{ - auth::{CredentialProvider, ProviderCredential}, - 
models::LanguageModel, -}; +use crate::{auth::CredentialProvider, models::LanguageModel}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; } -pub trait CompletionProvider { +pub trait CompletionProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - self.credential_provider().save_credentials(cx, credential); - } - fn delete_credentials(&self, cx: &AppContext) { - self.credential_provider().delete_credentials(cx); - } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 50f04232ab..6768b7ce7b 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -2,12 +2,11 @@ use std::time::Instant; use anyhow::Result; use async_trait::async_trait; -use gpui::AppContext; use ordered_float::OrderedFloat; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; -use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::auth::CredentialProvider; use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] @@ -70,17 +69,9 @@ impl Embedding { } #[async_trait] -pub trait EmbeddingProvider: Sync + Send { +pub trait EmbeddingProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs deleted file mode 100644 index 7cb51ab449..0000000000 --- a/crates/ai/src/providers/open_ai/auth.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::env; - -use gpui::AppContext; -use util::ResultExt; - -use crate::auth::{CredentialProvider, ProviderCredential}; -use crate::providers::open_ai::OPENAI_API_URL; - -#[derive(Clone)] -pub struct OpenAICredentialProvider {} - -impl CredentialProvider for OpenAICredentialProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - - if let Some(api_key) = api_key { - ProviderCredential::Credentials { api_key } - } else { - ProviderCredential::NoCredentials - } - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - match credential { - ProviderCredential::Credentials { api_key } => { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - } - _ => {} - } - } - fn delete_credentials(&self, cx: &AppContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - } -} diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index febe491123..02d25a7eec 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -3,14 +3,17 @@ 
use futures::{ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt, }; -use gpui::executor::Background; +use gpui::{executor::Background, AppContext}; use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use std::{ + env, fmt::{self, Display}, io, sync::Arc, }; +use util::ResultExt; use crate::{ auth::{CredentialProvider, ProviderCredential}, @@ -18,9 +21,7 @@ use crate::{ models::LanguageModel, }; -use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; - -pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -194,42 +195,83 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, - credential: ProviderCredential, + credential: Arc>, executor: Arc, } impl OpenAICompletionProvider { - pub fn new( - model_name: &str, - credential: ProviderCredential, - executor: Arc, - ) -> Self { + pub fn new(model_name: &str, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); - let credential_provider = OpenAICredentialProvider {}; + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); Self { model, - credential_provider, credential, executor, } } } +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. 
} => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + impl CompletionProvider for OpenAICompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } - fn credential_provider(&self) -> Box { - let provider: Box = Box::new(self.credential_provider.clone()); - provider - } fn complete( &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let credential = self.credential.clone(); + let credential = self.credential.read().clone(); let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index dafc94580d..fbfd0028f9 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,27 +2,29 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::serde_json; +use gpui::{serde_json, AppContext}; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; +use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; +use util::ResultExt; use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use crate::providers::open_ai::auth::OpenAICredentialProvider; +use crate::providers::open_ai::OPENAI_API_URL; lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); @@ -31,7 +33,7 @@ lazy_static! 
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, + credential: Arc>, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider { let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); OpenAIEmbeddingProvider { model, - credential_provider: OpenAICredentialProvider {}, + credential, client, executor, rate_limit_count_rx, @@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider { } } + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + fn resolve_rate_limit(&self) { let reset_time = *self.rate_limit_count_tx.lock().borrow(); @@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider { } } +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { fn base_model(&self) -> Box { @@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { model } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } - fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { *self.rate_limit_count_rx.borrow() } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = match credential { - ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), - _ => Err(anyhow!("no api key provided")), - }?; + let api_key = self.get_api_key()?; let mut request_number = 0; let mut rate_limiting = false; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 49e29fbc8c..7d2f86045d 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ 
b/crates/ai/src/providers/open_ai/mod.rs @@ -1,4 +1,3 @@ -pub mod auth; pub mod completion; pub mod embedding; pub mod model; @@ -6,3 +5,5 @@ pub mod model; pub use completion::*; pub use embedding::*; pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs new file mode 100644 index 0000000000..c7d67f2ba1 --- /dev/null +++ b/crates/ai/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index b8f99af400..bc9a6a3e43 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -5,10 +5,11 @@ use std::{ use async_trait::async_trait; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::AppContext; use parking_lot::Mutex; use crate::{ - auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, @@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel { pub struct FakeEmbeddingProvider { pub embedding_count: AtomicUsize, - pub credential_provider: NullCredentialProvider, } impl Clone for FakeEmbeddingProvider { fn clone(&self) -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), - credential_provider: self.credential_provider.clone(), } } } @@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider { fn default() -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::default(), - credential_provider: NullCredentialProvider {}, } } } @@ -99,16 +97,22 @@ impl FakeEmbeddingProvider { } } +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { fn base_model(&self) -> Box { Box::new(FakeLanguageModel { capacity: 1000 }) } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } fn max_tokens_per_batch(&self) -> usize { 1000 } @@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { None } - async fn embed_batch( - &self, - spans: Vec, - _credential: ProviderCredential, - ) -> anyhow::Result> { + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); @@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider { } } -pub struct TestCompletionProvider { +pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } -impl TestCompletionProvider { +impl FakeCompletionProvider { pub fn new() -> Self { Self { last_completion_tx: Mutex::new(None), @@ -150,14 +150,22 @@ impl TestCompletionProvider { } } -impl CompletionProvider for TestCompletionProvider { +impl 
CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } - fn credential_provider(&self) -> Box { - Box::new(NullCredentialProvider {}) - } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index f9187b8785..c10ad2c362 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -10,7 +10,7 @@ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, }, }; @@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, - cmp, env, + cmp, fmt::Write, iter, ops::Range, @@ -210,7 +210,6 @@ impl AssistantPanel { // Defaulting currently to GPT4, allow for this to be set via config. let completion_provider = Box::new(OpenAICompletionProvider::new( "gpt-4", - ProviderCredential::NoCredentials, cx.background().clone(), )); @@ -298,7 +297,6 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -330,7 +328,6 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - credential, cx.background().clone(), )); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 7f4c95f655..8d8e49902f 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::test::TestCompletionProvider; + use ai::test::FakeCompletionProvider; use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; @@ -379,7 +379,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -445,7 +445,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -511,7 +511,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), diff --git a/crates/semantic_index/src/embedding_queue.rs 
b/crates/semantic_index/src/embedding_queue.rs index 6f792c78e2..6ae8faa4cd 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; +use ai::embedding::EmbeddingProvider; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,6 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - pub provider_credential: ProviderCredential, } #[derive(Clone)] @@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed { } impl EmbeddingQueue { - pub fn new( - embedding_provider: Arc, - executor: Arc, - provider_credential: ProviderCredential, - ) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, @@ -64,14 +59,9 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - provider_credential, } } - pub fn set_credential(&mut self, credential: ProviderCredential) { - self.provider_credential = credential; - } - pub fn push(&mut self, file: FileToEmbed) { if file.spans.is_empty() { self.finished_files_tx.try_send(file).unwrap(); @@ -118,7 +108,6 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +132,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, credential).await { + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 7fb5f749b4..818faa0444 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,7 +7,6 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; -use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -125,8 +124,6 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - provider_credential: ProviderCredential, - embedding_queue: Arc>, } struct ProjectState { @@ -281,24 +278,17 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let existing_credential = self.provider_credential.clone(); - let credential = match existing_credential { - ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), - _ => existing_credential, - }; + if !self.embedding_provider.has_credentials() { + self.embedding_provider.retrieve_credentials(cx); + } else { + return true; + } - self.provider_credential = credential.clone(); - self.embedding_queue.lock().set_credential(credential); - self.is_authenticated() + self.embedding_provider.has_credentials() } pub fn is_authenticated(&self) -> bool { - let credential = &self.provider_credential; - match credential { - &ProviderCredential::Credentials { .. 
} => true, - &ProviderCredential::NotNeeded => true, - _ => false, - } + self.embedding_provider.has_credentials() } pub fn enabled(cx: &AppContext) -> bool { @@ -348,7 +338,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -413,8 +403,6 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - provider_credential: ProviderCredential::NoCredentials, - embedding_queue } })) } @@ -729,14 +717,13 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); let query = embedding_provider - .embed_batch(vec![query], credential) + .embed_batch(vec![query]) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -954,7 +941,6 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -969,15 +955,10 @@ impl SemanticIndex { .parse_file_with_template(None, &snapshot.text(), language) .log_err() .unwrap_or_default(); - if Self::embed_spans( - &mut spans, - embedding_provider.as_ref(), - &db, - credential.clone(), - ) - .await - .log_err() - .is_some() + if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db) + .await + .log_err() + .is_some() { for span in spans { let similarity = span.embedding.unwrap().similarity(&query); @@ -1201,7 +1182,6 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1224,7 +1204,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential.clone()) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1236,7 +1216,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7d5a4e22e8..7a91d1e100 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new( - embedding_provider.clone(), - cx.background(), - ai::auth::ProviderCredential::NoCredentials, - ); + let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background()); for file in &files { queue.push(file.clone()); } From 
f3c113fe02489748823b67f6e3340a094d412795 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 11:07:24 -0400 Subject: [PATCH 22/23] clean up warnings and fix tests in the ai crate --- crates/ai/src/completion.rs | 7 + crates/ai/src/prompts/base.rs | 4 +- crates/ai/src/providers/open_ai/completion.rs | 8 + crates/ai/src/test.rs | 14 ++ crates/assistant/src/assistant_panel.rs | 212 ++++++------------ 5 files changed, 102 insertions(+), 143 deletions(-) diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 7fdc49e918..30a60fcf1d 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -13,4 +13,11 @@ pub trait CompletionProvider: CredentialProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } } diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs index a2106c7410..75bad00154 100644 --- a/crates/ai/src/prompts/base.rs +++ b/crates/ai/src/prompts/base.rs @@ -147,7 +147,7 @@ pub(crate) mod tests { content = args.model.truncate( &content, max_token_length, - TruncationDirection::Start, + TruncationDirection::End, )?; token_count = max_token_length; } @@ -172,7 +172,7 @@ pub(crate) mod tests { content = args.model.truncate( &content, max_token_length, - TruncationDirection::Start, + TruncationDirection::End, )?; token_count = max_token_length; } diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 02d25a7eec..94685fd233 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -193,6 +193,7 @@ pub async fn stream_completion( } } +#[derive(Clone)] pub struct OpenAICompletionProvider { model: OpenAILanguageModel, credential: Arc>, @@ -271,6 +272,10 @@ impl CompletionProvider for OpenAICompletionProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI includes a 'model' parameter. + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model-based, due to the language model. + // At some point in the future we should rectify this. let credential = self.credential.read().clone(); let request = stream_completion(credential, self.executor.clone(), prompt); async move { @@ -287,4 +292,7 @@ impl CompletionProvider for OpenAICompletionProvider { } .boxed() } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } } diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index bc9a6a3e43..d4165f3cca 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -33,7 +33,10 @@ impl LanguageModel for FakeLanguageModel { length: usize, direction: TruncationDirection, ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + if length > self.count_tokens(content)? 
{ + println!("NOT TRUNCATING"); return anyhow::Ok(content.to_string()); } @@ -133,6 +136,14 @@ pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + impl FakeCompletionProvider { pub fn new() -> Self { Self { @@ -174,4 +185,7 @@ impl CompletionProvider for FakeCompletionProvider { *self.last_completion_tx.lock() = Some(tx); async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c10ad2c362..d0c7e7e883 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -9,9 +9,7 @@ use crate::{ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, - providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, - }, + providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage}, }; use ai::prompts::repository_context::PromptCodeSnippet; @@ -47,7 +45,7 @@ use search::BufferSearchBar; use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ - cell::{Cell, RefCell}, + cell::Cell, cmp, fmt::Write, iter, @@ -144,10 +142,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - credential: Rc>, completion_provider: Box, api_key_editor: Option>, - has_read_credentials: bool, languages: Arc, fs: Arc, subscriptions: Vec, @@ -223,10 +219,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), completion_provider, api_key_editor: None, - has_read_credentials: false, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, @@ -265,7 +259,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { + if this.update(cx, |assistant, _| assistant.has_credentials()) { this } else { workspace.focus_panel::(cx); @@ -331,6 +325,9 @@ impl AssistantPanel { cx.background().clone(), )); + // Retrieving credentials authenticates the provider. + // provider.retrieve_credentials(cx); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -814,7 +811,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.credential.clone(), + self.completion_provider.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -883,9 +880,8 @@ impl AssistantPanel { let credential = ProviderCredential::Credentials { api_key: api_key.clone(), }; - self.completion_provider - .save_credentials(cx, credential.clone()); - *self.credential.borrow_mut() = credential; + + self.completion_provider.save_credentials(cx, credential); self.api_key_editor.take(); cx.focus_self(); @@ -898,7 +894,6 @@ impl AssistantPanel { fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { self.completion_provider.delete_credentials(cx); - *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1157,19 +1152,12 @@ impl AssistantPanel { 
let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize( - saved_conversation, - path.clone(), - credential, - languages, - cx, - ) + Conversation::deserialize(saved_conversation, path.clone(), languages, cx) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1193,39 +1181,12 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { - let credential = self.load_credentials(cx); - match credential { - ProviderCredential::Credentials { .. } => true, - ProviderCredential::NotNeeded => true, - ProviderCredential::NoCredentials => false, - } + fn has_credentials(&mut self) -> bool { + self.completion_provider.has_credentials() } - fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { - let existing_credential = self.credential.clone(); - let existing_credential = existing_credential.borrow().clone(); - match existing_credential { - ProviderCredential::NoCredentials => { - if !self.has_read_credentials { - self.has_read_credentials = true; - let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); - - match retrieved_credentials { - ProviderCredential::NoCredentials {} => { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - _ => { - *self.credential.borrow_mut() = retrieved_credentials; - } - } - } - } - _ => {} - } - - self.credential.borrow().clone() + fn load_credentials(&mut self, cx: &mut ViewContext) { + self.completion_provider.retrieve_credentials(cx); } } @@ -1475,10 +1436,10 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, + completion_provider: Box, } impl Entity for Conversation { @@ -1487,10 +1448,9 @@ impl Entity for Conversation { impl Conversation { fn new( - credential: Rc>, - language_registry: Arc, cx: &mut ModelContext, + completion_provider: Box, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1529,8 +1489,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - credential, buffer, + completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1576,7 +1536,6 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1585,6 +1544,10 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; + let completion_provider: Box = Box::new( + OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), + ); + completion_provider.retrieve_credentials(cx); let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); let mut next_message_id = MessageId(0); @@ -1631,8 +1594,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: 
Task::ready(Ok(())), path: Some(path), - credential, buffer, + completion_provider, }; this.count_remaining_tokens(cx); this @@ -1753,12 +1716,8 @@ impl Conversation { } if should_assist { - let credential = self.credential.borrow().clone(); - match credential { - ProviderCredential::NoCredentials => { - return Default::default(); - } - _ => {} + if !self.completion_provider.has_credentials() { + return Default::default(); } let request: Box = Box::new(OpenAIRequest { @@ -1773,7 +1732,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1791,33 +1750,28 @@ impl Conversation { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = - this.message_anchors.iter().position(|message| { - message.id == assistant_message_id - })?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); + let text = message?; - Some(()) + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] 
+ .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); }); - } + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); smol::future::yield_now().await; } @@ -2039,13 +1993,8 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let credential = self.credential.borrow().clone(); - - match credential { - ProviderCredential::NoCredentials => { - return; - } - _ => {} + if !self.completion_provider.has_credentials() { + return; } let messages = self @@ -2065,23 +2014,20 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); self.pending_summary = cx.spawn(|this, mut cx| { async move { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } + let text = message?; + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); } this.update(&mut cx, |this, cx| { @@ -2255,13 +2201,14 @@ struct ConversationEditor { impl ConversationEditor { fn new( - credential: Rc>, + completion_provider: Box, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx)); + let conversation = + cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3450,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use crate::MessageId; + use ai::test::FakeCompletionProvider; use gpui::AppContext; #[gpui::test] @@ -3457,13 +3405,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3591,13 +3535,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3693,13 +3633,8 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let 
conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3781,13 +3716,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry.clone(), - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = + cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3824,7 +3755,6 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, ) From dc8a8538421102c1f41cbec1f79f5130ecb7ed35 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:27:05 +0100 Subject: [PATCH 23/23] lsp/next-ls: Fix wrong nls binary being fetched. (#3181) CPU types had to be swapped around. Fixed zed-industries/community#2185 Release Notes: - Fixed Elixir next-ls LSP installation failing due to fetching a binary for the wrong architecture (zed-industries/community#2185). --- crates/zed/src/languages/elixir.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/zed/src/languages/elixir.rs b/crates/zed/src/languages/elixir.rs index 5c0ff273ae..df438d89ee 100644 --- a/crates/zed/src/languages/elixir.rs +++ b/crates/zed/src/languages/elixir.rs @@ -321,8 +321,8 @@ impl LspAdapter for NextLspAdapter { latest_github_release("elixir-tools/next-ls", false, delegate.http_client()).await?; let version = release.name.clone(); let platform = match consts::ARCH { - "x86_64" => "darwin_arm64", - "aarch64" => "darwin_amd64", + "x86_64" => "darwin_amd64", + "aarch64" => "darwin_arm64", other => bail!("Running on unsupported platform: {other}"), }; let asset_name = format!("next_ls_{}", platform);
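
For reference, a minimal standalone sketch of the architecture mapping corrected in PATCH 23/23. The helper name `next_ls_platform` and the unit test are illustrative additions, not part of the patch; only the `consts::ARCH` value to asset-suffix pairing is taken from the diff above. The likely root of the original swap: Rust's `std::env::consts::ARCH` reports "x86_64"/"aarch64", while the next-ls release assets use the Go-style "amd64"/"arm64" suffixes, so the two naming schemes are easy to cross.

use std::env::consts;

// Maps the host CPU architecture, as reported by Rust, to the suffix
// used in elixir-tools/next-ls release asset names on macOS.
fn next_ls_platform(arch: &str) -> Result<&'static str, String> {
    match arch {
        "x86_64" => Ok("darwin_amd64"),
        "aarch64" => Ok("darwin_arm64"),
        other => Err(format!("Running on unsupported platform: {other}")),
    }
}

fn main() {
    // e.g. "aarch64" on Apple Silicon resolves to "next_ls_darwin_arm64".
    match next_ls_platform(consts::ARCH) {
        Ok(platform) => println!("asset name: next_ls_{}", platform),
        Err(error) => eprintln!("{error}"),
    }
}

#[cfg(test)]
mod tests {
    use super::next_ls_platform;

    #[test]
    fn cpu_types_map_to_matching_asset_suffixes() {
        // Pins the pairing corrected by this patch so it cannot be
        // silently swapped back.
        assert_eq!(next_ls_platform("x86_64").unwrap(), "darwin_amd64");
        assert_eq!(next_ls_platform("aarch64").unwrap(), "darwin_arm64");
    }
}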