From 04ab68502b950e7d23c0522347b586ebd3670a4f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 14:40:31 -0400 Subject: [PATCH] port ai crate to ai2, with all tests passing --- Cargo.lock | 28 ++ crates/Cargo.toml | 38 ++ crates/ai2/Cargo.toml | 38 ++ crates/ai2/src/ai2.rs | 8 + crates/ai2/src/auth.rs | 17 + crates/ai2/src/completion.rs | 23 ++ crates/ai2/src/embedding.rs | 123 +++++++ crates/ai2/src/models.rs | 16 + crates/ai2/src/prompts/base.rs | 330 ++++++++++++++++++ crates/ai2/src/prompts/file_context.rs | 164 +++++++++ crates/ai2/src/prompts/generate.rs | 99 ++++++ crates/ai2/src/prompts/mod.rs | 5 + crates/ai2/src/prompts/preamble.rs | 52 +++ crates/ai2/src/prompts/repository_context.rs | 98 ++++++ crates/ai2/src/providers/mod.rs | 1 + .../ai2/src/providers/open_ai/completion.rs | 306 ++++++++++++++++ crates/ai2/src/providers/open_ai/embedding.rs | 313 +++++++++++++++++ crates/ai2/src/providers/open_ai/mod.rs | 9 + crates/ai2/src/providers/open_ai/model.rs | 57 +++ crates/ai2/src/providers/open_ai/new.rs | 11 + crates/ai2/src/test.rs | 193 ++++++++++ crates/zed2/Cargo.toml | 1 + 22 files changed, 1930 insertions(+) create mode 100644 crates/Cargo.toml create mode 100644 crates/ai2/Cargo.toml create mode 100644 crates/ai2/src/ai2.rs create mode 100644 crates/ai2/src/auth.rs create mode 100644 crates/ai2/src/completion.rs create mode 100644 crates/ai2/src/embedding.rs create mode 100644 crates/ai2/src/models.rs create mode 100644 crates/ai2/src/prompts/base.rs create mode 100644 crates/ai2/src/prompts/file_context.rs create mode 100644 crates/ai2/src/prompts/generate.rs create mode 100644 crates/ai2/src/prompts/mod.rs create mode 100644 crates/ai2/src/prompts/preamble.rs create mode 100644 crates/ai2/src/prompts/repository_context.rs create mode 100644 crates/ai2/src/providers/mod.rs create mode 100644 crates/ai2/src/providers/open_ai/completion.rs create mode 100644 crates/ai2/src/providers/open_ai/embedding.rs create mode 100644 crates/ai2/src/providers/open_ai/mod.rs create mode 100644 crates/ai2/src/providers/open_ai/model.rs create mode 100644 crates/ai2/src/providers/open_ai/new.rs create mode 100644 crates/ai2/src/test.rs diff --git a/Cargo.lock b/Cargo.lock index 0caaaeceef..a5d187d08e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,33 @@ dependencies = [ "util", ] +[[package]] +name = "ai2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "futures 0.3.28", + "gpui2", + "isahc", + "language2", + "lazy_static", + "log", + "matrixmultiply", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "parse_duration", + "postage", + "rand 0.8.5", + "regex", + "rusqlite", + "serde", + "serde_json", + "tiktoken-rs", + "util", +] + [[package]] name = "alacritty_config" version = "0.1.2-dev" @@ -10903,6 +10930,7 @@ dependencies = [ name = "zed2" version = "0.109.0" dependencies = [ + "ai2", "anyhow", "async-compression", "async-recursion 0.3.2", diff --git a/crates/Cargo.toml b/crates/Cargo.toml new file mode 100644 index 0000000000..fb49a4b515 --- /dev/null +++ b/crates/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui = { path = "../gpui" } +util = { path = "../util" } +language = { path = "../language" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true 
+isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/ai2/Cargo.toml b/crates/ai2/Cargo.toml new file mode 100644 index 0000000000..4f06840e8e --- /dev/null +++ b/crates/ai2/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = "../util" } +language2 = { path = "../language2" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } diff --git a/crates/ai2/src/ai2.rs b/crates/ai2/src/ai2.rs new file mode 100644 index 0000000000..dda22d2a1d --- /dev/null +++ b/crates/ai2/src/ai2.rs @@ -0,0 +1,8 @@ +pub mod auth; +pub mod completion; +pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs new file mode 100644 index 0000000000..e4670bb449 --- /dev/null +++ b/crates/ai2/src/auth.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use gpui2::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +#[async_trait] +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential); + async fn delete_credentials(&self, cx: &mut AppContext); +} diff --git a/crates/ai2/src/completion.rs b/crates/ai2/src/completion.rs new file mode 100644 index 0000000000..30a60fcf1d --- /dev/null +++ b/crates/ai2/src/completion.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; + +use crate::{auth::CredentialProvider, models::LanguageModel}; + +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} + +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} diff --git a/crates/ai2/src/embedding.rs b/crates/ai2/src/embedding.rs new file mode 100644 index 0000000000..7ea4786178 --- /dev/null +++ b/crates/ai2/src/embedding.rs @@ -0,0 +1,123 @@ +use std::time::Instant; + +use anyhow::Result; +use async_trait::async_trait; +use ordered_float::OrderedFloat; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, 
ValueRef}; +use rusqlite::ToSql; + +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; + +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(pub Vec<f32>); + +// This is needed for semantic index functionality +// Unfortunately it has to live wherever the "Embedding" struct is created. +// Keeping this in here, though, introduces a 'rusqlite' dependency into AI +// which is less than ideal +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult<Self> { + let bytes = value.as_blob()?; + let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} +impl From<Vec<f32>> for Embedding { + fn from(value: Vec<f32>) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + OrderedFloat(result) + } +} + +#[async_trait] +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box<dyn LanguageModel>; + async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>; + fn max_tokens_per_batch(&self) -> usize; + fn rate_limit_expiration(&self) -> Option<Instant>; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui2::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6.
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> { + OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()) + } + } +} diff --git a/crates/ai2/src/models.rs b/crates/ai2/src/models.rs new file mode 100644 index 0000000000..1db3d58c6f --- /dev/null +++ b/crates/ai2/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result<usize>; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result<String>; + fn capacity(&self) -> anyhow::Result<usize>; +} diff --git a/crates/ai2/src/prompts/base.rs b/crates/ai2/src/prompts/base.rs new file mode 100644 index 0000000000..29091d0f5b --- /dev/null +++ b/crates/ai2/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language2::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage defaults well +pub struct PromptArguments { + pub model: Arc<dyn LanguageModel>, + pub user_prompt: Option<String>, + pub language_name: Option<String>, + pub project_name: Option<String>, + pub snippets: Vec<PromptCodeSnippet>, + pub reserved_tokens: usize, + pub buffer: Option<BufferSnapshot>, + pub selected_range: Option<Range<usize>>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option<usize>, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { ..
}, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? - self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { 
order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai2/src/prompts/file_context.rs b/crates/ai2/src/prompts/file_context.rs new file mode 100644 index 0000000000..4a741beb24 --- /dev/null +++ b/crates/ai2/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language2::BufferSnapshot; +use language2::ToOffset; + +use crate::models::LanguageModel; +use 
crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? 
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai2/src/prompts/generate.rs b/crates/ai2/src/prompts/generate.rs new file mode 100644 index 0000000000..c7be620107 --- /dev/null +++ b/crates/ai2/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." 
+ ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the user's prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the user's prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|>' spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the user's prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai2/src/prompts/mod.rs b/crates/ai2/src/prompts/mod.rs new file mode 100644 index 0000000000..0025269a44 --- /dev/null +++ b/crates/ai2/src/prompts/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai2/src/prompts/preamble.rs b/crates/ai2/src/prompts/preamble.rs new file mode 100644 index 0000000000..92e0edeb78 --- /dev/null +++ b/crates/ai2/src/prompts/preamble.rs @@ -0,0 +1,52 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option<usize>, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)?
+ args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs new file mode 100644 index 0000000000..78db5a1651 --- /dev/null +++ b/crates/ai2/src/prompts/repository_context.rs @@ -0,0 +1,98 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui2::{AsyncAppContext, Handle}; +use language2::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new( + buffer: Handle, + range: Range, + cx: &mut AsyncAppContext, + ) -> anyhow::Result { + let (content, language_name, file_path) = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + })?; + + anyhow::Ok(PromptCodeSnippet { + path: file_path, + language_name, + content, + }) + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai2/src/providers/mod.rs b/crates/ai2/src/providers/mod.rs new file mode 100644 index 0000000000..acd0f9d910 --- /dev/null +++ b/crates/ai2/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git 
a/crates/ai2/src/providers/open_ai/completion.rs b/crates/ai2/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000..eca5611027 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/completion.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui2::{AppContext, Executor}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +#[async_trait] +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box<dyn LanguageModel> { + let model: Box<dyn LanguageModel> = Box::new(self.model.clone()); + model + } + fn complete( + &self, + prompt: Box<dyn CompletionRequest>, + ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> { + // Currently the CompletionRequest for OpenAI includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model-based, due to the language model. + // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box<dyn CompletionProvider> { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000..fc49c15134 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/embedding.rs @@ -0,0 +1,313 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui2::Executor; +use gpui2::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static!
{ + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai2/src/providers/open_ai/mod.rs b/crates/ai2/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000..7d2f86045d --- /dev/null +++ b/crates/ai2/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai2/src/providers/open_ai/model.rs b/crates/ai2/src/providers/open_ai/model.rs new file mode 100644 index 0000000000..6e306c80b9 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai2/src/providers/open_ai/new.rs b/crates/ai2/src/providers/open_ai/new.rs new file mode 100644 index 
0000000000..c7d67f2ba1 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/test.rs b/crates/ai2/src/test.rs new file mode 100644 index 0000000000..ee88529aec --- /dev/null +++ b/crates/ai2/src/test.rs @@ -0,0 +1,193 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui2::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +#[async_trait] +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/zed2/Cargo.toml b/crates/zed2/Cargo.toml index a6b31871dd..9f681a49e9 100644 --- a/crates/zed2/Cargo.toml 
+++ b/crates/zed2/Cargo.toml @@ -15,6 +15,7 @@ name = "Zed" path = "src/main.rs" [dependencies] +ai2 = { path = "../ai2"} # audio = { path = "../audio" } # activity_indicator = { path = "../activity_indicator" } # auto_update = { path = "../auto_update" }