From 76ce52df4ee0f4b4b977093f096c76e15b852ae3 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 30 Aug 2023 16:01:28 -0400 Subject: [PATCH] move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch" Co-authored-by: Max --- crates/semantic_index/src/embedding.rs | 47 ++---- crates/semantic_index/src/embedding_queue.rs | 140 ++++++++++++++++ crates/semantic_index/src/parsing.rs | 10 +- .../src/semantic_index_tests.rs | 154 +++++++++++++----- crates/util/src/util.rs | 35 ++-- 5 files changed, 295 insertions(+), 91 deletions(-) create mode 100644 crates/semantic_index/src/embedding_queue.rs diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index cba34439c8..7db22c3716 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; - fn count_tokens(&self, span: &str) -> usize; - fn should_truncate(&self, span: &str) -> bool; - fn truncate(&self, span: &str) -> String; + async fn embed_batch(&self, spans: Vec) -> Result>>; + fn max_tokens_per_batch(&self) -> usize; + fn truncate(&self, span: &str) -> (String, usize); } pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings { span.to_string() }; - output + (output, token_count) } } @@ -125,19 +119,14 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { - fn count_tokens(&self, span: &str) -> usize { - // For Dummy Providers, we are going to use OpenAI tokenization for ease - let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - tokens.len() + fn max_tokens_per_batch(&self) -> usize { + OPENAI_INPUT_LIMIT } - fn should_truncate(&self, span: &str) -> bool { - self.count_tokens(span) > OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> String { + fn truncate(&self, span: &str) -> (String, usize) { let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { + let token_count = tokens.len(); + let output = if token_count > OPENAI_INPUT_LIMIT { tokens.truncate(OPENAI_INPUT_LIMIT); OPENAI_BPE_TOKENIZER .decode(tokens) @@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings { span.to_string() }; - output + (output, token_count) } - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; @@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings { let mut request_number = 0; let mut request_timeout: u64 = 10; - let mut truncated = false; let mut response: Response; - let mut spans: Vec = spans.iter().map(|x| x.to_string()).collect(); while request_number < MAX_RETRIES { response = self .send_request( diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs new file mode 100644 index 0000000000..6609c39e78 --- /dev/null +++ b/crates/semantic_index/src/embedding_queue.rs @@ -0,0 +1,140 @@ +use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; + +use gpui::AppContext; +use parking_lot::Mutex; +use smol::channel; + +use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle}; + +#[derive(Clone)] +pub struct FileToEmbed { + pub worktree_id: i64, + pub path: PathBuf, + pub mtime: SystemTime, + pub documents: Vec, + pub job_handle: JobHandle, +} + +impl std::fmt::Debug for FileToEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileToEmbed") + .field("worktree_id", &self.worktree_id) + .field("path", &self.path) + .field("mtime", &self.mtime) + .field("document", &self.documents) + .finish_non_exhaustive() + } +} + +impl PartialEq for FileToEmbed { + fn eq(&self, other: &Self) -> bool { + self.worktree_id == other.worktree_id + && self.path == other.path + && self.mtime == other.mtime + && self.documents == other.documents + } +} + +pub struct EmbeddingQueue { + embedding_provider: Arc, + pending_batch: Vec, + pending_batch_token_count: usize, + finished_files_tx: channel::Sender, + finished_files_rx: channel::Receiver, +} + +pub struct FileToEmbedFragment { + file: Arc>, + document_range: Range, +} + +impl EmbeddingQueue { + pub fn new(embedding_provider: Arc) -> Self { + let (finished_files_tx, finished_files_rx) = channel::unbounded(); + Self { + embedding_provider, + pending_batch: Vec::new(), + pending_batch_token_count: 0, + finished_files_tx, + finished_files_rx, + } + } + + pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { + let file = Arc::new(Mutex::new(file)); + + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: 0..0, + }); + + let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + for (ix, document) in file.lock().documents.iter().enumerate() { + let next_token_count = self.pending_batch_token_count + document.token_count; + if next_token_count > self.embedding_provider.max_tokens_per_batch() { + let range_end = fragment_range.end; + self.flush(cx); + self.pending_batch.push(FileToEmbedFragment { + file: file.clone(), + document_range: range_end..range_end, + }); + fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range; + } + + fragment_range.end = ix + 1; + self.pending_batch_token_count += document.token_count; + } + } + + pub fn flush(&mut self, cx: &mut AppContext) { + let batch = mem::take(&mut self.pending_batch); + self.pending_batch_token_count = 0; + if batch.is_empty() { + return; + } + + let finished_files_tx = self.finished_files_tx.clone(); + let embedding_provider = self.embedding_provider.clone(); + cx.background().spawn(async move { + let mut spans = Vec::new(); + for fragment in &batch { + let file = fragment.file.lock(); + spans.extend( + file.documents[fragment.document_range.clone()] + .iter() + .map(|d| d.content.clone()), + ); + } + + match embedding_provider.embed_batch(spans).await { + Ok(embeddings) => { + let mut embeddings = embeddings.into_iter(); + for fragment in batch { + for document in + &mut fragment.file.lock().documents[fragment.document_range.clone()] + { + if let Some(embedding) = embeddings.next() { + document.embedding = embedding; + } else { + // + log::error!("number of embeddings returned different from number of documents"); + } + } + + if let Some(file) = Arc::into_inner(fragment.file) { + finished_files_tx.try_send(file.into_inner()).unwrap(); + } + } + } + Err(error) => { + log::error!("{:?}", error); + } + } + }) + .detach(); + } + + pub fn finished_files(&self) -> channel::Receiver { + self.finished_files_rx.clone() + } +} diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index 00849580bb..51f1bd7ca9 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -72,8 +72,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -93,8 +92,7 @@ impl CodeContextRetriever { let mut sha1 = Sha1::new(); sha1.update(&document_span); - let token_count = self.embedding_provider.count_tokens(&document_span); - let document_span = self.embedding_provider.truncate(&document_span); + let (document_span, token_count) = self.embedding_provider.truncate(&document_span); Ok(vec![Document { range: 0..content.len(), @@ -183,8 +181,8 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &document.content); - let token_count = self.embedding_provider.count_tokens(&document_content); - let document_content = self.embedding_provider.truncate(&document_content); + let (document_content, token_count) = + self.embedding_provider.truncate(&document_content); document.content = document_content; document.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7093cf9fcf..7178987165 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1,14 +1,16 @@ use crate::{ db::dot, embedding::{DummyEmbeddings, EmbeddingProvider}, + embedding_queue::EmbeddingQueue, parsing::{subtract_ranges, CodeContextRetriever, Document}, semantic_index_settings::SemanticIndexSettings, - SearchResult, SemanticIndex, + FileToEmbed, JobHandle, SearchResult, SemanticIndex, }; use anyhow::Result; use async_trait::async_trait; use gpui::{Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; +use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project}; use rand::{rngs::StdRng, Rng}; @@ -20,8 +22,10 @@ use std::{ atomic::{self, AtomicUsize}, Arc, }, + time::SystemTime, }; use unindent::Unindent; +use util::RandomCharIter; #[ctor::ctor] fn init_logger() { @@ -32,11 +36,7 @@ fn init_logger() { #[gpui::test] async fn test_semantic_index(cx: &mut TestAppContext) { - cx.update(|cx| { - cx.set_global(SettingsStore::test(cx)); - settings::register::(cx); - settings::register::(cx); - }); + init_test(cx); let fs = FakeFs::new(cx.background()); fs.insert_tree( @@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let db_path = db_dir.path().join("db.sqlite"); let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let store = SemanticIndex::new( + let semantic_index = SemanticIndex::new( fs.clone(), db_path, embedding_provider.clone(), @@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) { let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - let _ = store + let _ = semantic_index .update(cx, |store, cx| { store.initialize_project(project.clone(), cx) }) .await; - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); assert_eq!(*outstanding_file_count.borrow(), 0); - let search_results = store + let search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { // Test Include Files Functonality let include_files = vec![PathMatcher::new("*.rs").unwrap()]; let exclude_files = vec![PathMatcher::new("*.rs").unwrap()]; - let rust_only_search_results = store + let rust_only_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx, ); - let no_rust_search_results = store + let no_rust_search_results = semantic_index .update(cx, |store, cx| { store.search_project( project.clone(), @@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) { cx.foreground().run_until_parked(); let prev_embedding_count = embedding_provider.embedding_count(); - let (file_count, outstanding_file_count) = store + let (file_count, outstanding_file_count) = semantic_index .update(cx, |store, cx| store.index_project(project.clone(), cx)) .await .unwrap(); @@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) { ); } +#[gpui::test(iterations = 10)] +async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { + let (outstanding_job_count, _) = postage::watch::channel_with(0); + let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count)); + + let files = (1..=3) + .map(|file_ix| FileToEmbed { + worktree_id: 5, + path: format!("path-{file_ix}").into(), + mtime: SystemTime::now(), + documents: (0..rng.gen_range(4..22)) + .map(|document_ix| { + let content_len = rng.gen_range(10..100); + Document { + range: 0..10, + embedding: Vec::new(), + name: format!("document {document_ix}"), + content: RandomCharIter::new(&mut rng) + .with_simple_text() + .take(content_len) + .collect(), + sha1: rng.gen(), + token_count: rng.gen_range(10..30), + } + }) + .collect(), + job_handle: JobHandle::new(&outstanding_job_count), + }) + .collect::>(); + + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); + let mut queue = EmbeddingQueue::new(embedding_provider.clone()); + + let finished_files = cx.update(|cx| { + for file in &files { + queue.push(file.clone(), cx); + } + queue.flush(cx); + queue.finished_files() + }); + + cx.foreground().run_until_parked(); + let mut embedded_files: Vec<_> = files + .iter() + .map(|_| finished_files.try_recv().expect("no finished file")) + .collect(); + + let expected_files: Vec<_> = files + .iter() + .map(|file| { + let mut file = file.clone(); + for doc in &mut file.documents { + doc.embedding = embedding_provider.embed_sync(doc.content.as_ref()); + } + file + }) + .collect(); + + embedded_files.sort_by_key(|f| f.path.clone()); + + assert_eq!(embedded_files, expected_files); +} + #[track_caller] fn assert_search_results( actual: &[SearchResult], @@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider { fn embedding_count(&self) -> usize { self.embedding_count.load(atomic::Ordering::SeqCst) } + + fn embed_sync(&self, span: &str) -> Vec { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result + } } #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { - fn count_tokens(&self, span: &str) -> usize { - span.len() + fn truncate(&self, span: &str) -> (String, usize) { + (span.to_string(), 1) } - fn should_truncate(&self, span: &str) -> bool { - false + fn max_tokens_per_batch(&self) -> usize { + 200 } - fn truncate(&self, span: &str) -> String { - span.to_string() - } - - async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + async fn embed_batch(&self, spans: Vec) -> Result>> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans - .iter() - .map(|span| { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result - }) - .collect()) + Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } @@ -1704,3 +1762,11 @@ fn test_subtract_ranges() { assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]); } + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + cx.set_global(SettingsStore::test(cx)); + settings::register::(cx); + settings::register::(cx); + }); +} diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index c8beb86aef..785426ed4c 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -260,11 +260,22 @@ pub fn defer(f: F) -> impl Drop { Defer(Some(f)) } -pub struct RandomCharIter(T); +pub struct RandomCharIter { + rng: T, + simple_text: bool, +} impl RandomCharIter { pub fn new(rng: T) -> Self { - Self(rng) + Self { + rng, + simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()), + } + } + + pub fn with_simple_text(mut self) -> Self { + self.simple_text = true; + self } } @@ -272,25 +283,27 @@ impl Iterator for RandomCharIter { type Item = char; fn next(&mut self) -> Option { - if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) { - return if self.0.gen_range(0..100) < 5 { + if self.simple_text { + return if self.rng.gen_range(0..100) < 5 { Some('\n') } else { - Some(self.0.gen_range(b'a'..b'z' + 1).into()) + Some(self.rng.gen_range(b'a'..b'z' + 1).into()) }; } - match self.0.gen_range(0..100) { + match self.rng.gen_range(0..100) { // whitespace - 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(), + 0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(), // two-byte greek letters - 20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))), + 20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))), // // three-byte characters - 33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(), + 33..=45 => ['✋', '✅', '❌', '❎', '⭐'] + .choose(&mut self.rng) + .copied(), // // four-byte characters - 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(), + 46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(), // ascii letters - _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()), + _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()), } } }