From afccf608f42d9b35d6b1942ae60734f3b3e8d3a9 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Wed, 5 Jul 2023 12:39:08 -0400
Subject: [PATCH] updated both embed and parsing tasks to be multi-threaded.

---
 Cargo.lock                              |  34 +-
 crates/vector_store/Cargo.toml          |   1 +
 crates/vector_store/src/embedding.rs    |  27 +-
 crates/vector_store/src/vector_store.rs | 411 +++++++++++++-----------
 4 files changed, 281 insertions(+), 192 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 59cf30001e..dbc2a1cbb0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -118,7 +118,7 @@ dependencies = [
  "settings",
  "smol",
  "theme",
- "tiktoken-rs",
+ "tiktoken-rs 0.4.2",
  "util",
  "workspace",
 ]
@@ -737,9 +737,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"

 [[package]]
 name = "base64"
-version = "0.21.0"
+version = "0.21.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
+checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"

 [[package]]
 name = "base64ct"
@@ -914,9 +914,9 @@ dependencies = [

 [[package]]
 name = "bstr"
-version = "1.4.0"
+version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09"
+checksum = "a246e68bb43f6cd9db24bea052a53e40405417c5fb372e3d1a8a7f770a564ef5"
 dependencies = [
  "memchr",
  "once_cell",
@@ -4812,7 +4812,7 @@ version = "1.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9bd9647b268a3d3e14ff09c23201133a62589c658db02bb7388c7246aafe0590"
 dependencies = [
- "base64 0.21.0",
+ "base64 0.21.2",
  "indexmap",
  "line-wrap",
  "quick-xml",
@@ -5529,7 +5529,7 @@ version = "0.11.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91"
 dependencies = [
- "base64 0.21.0",
+ "base64 0.21.2",
  "bytes 1.4.0",
  "encoding_rs",
  "futures-core",
@@ -5868,7 +5868,7 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
 dependencies = [
- "base64 0.21.0",
+ "base64 0.21.2",
 ]

 [[package]]
@@ -7118,7 +7118,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4"
 dependencies = [
  "anyhow",
- "base64 0.21.0",
+ "base64 0.21.2",
+ "bstr",
+ "fancy-regex",
+ "lazy_static",
+ "parking_lot 0.12.1",
+ "rustc-hash",
+]
+
+[[package]]
+name = "tiktoken-rs"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a99d843674a3468b4a9200a565bbe909a0152f95e82a52feae71e6bf2d4b49d"
+dependencies = [
+ "anyhow",
+ "base64 0.21.2",
  "bstr",
  "fancy-regex",
  "lazy_static",
@@ -8038,6 +8053,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
+ "tiktoken-rs 0.5.0",
  "tree-sitter",
  "tree-sitter-rust",
  "unindent",
diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml
index d1ad8a0f9b..854afe5b6e 100644
--- a/crates/vector_store/Cargo.toml
+++ b/crates/vector_store/Cargo.toml
@@ -31,6 +31,7 @@ serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
 matrixmultiply = "0.3.7"
+tiktoken-rs = "0.5.0"

 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
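Note on the dependency bump above: tiktoken-rs gives the vector store a local tokenizer, so spans can be token-counted and clipped to OpenAI's 8192-token embedding limit before a request is made. The patch only stages this (the truncate helper in embedding.rs below is still commented out); a self-contained sketch of the same idea, with an assumed helper name:

// Sketch only (not part of the patch): token-aware truncation with
// tiktoken-rs, mirroring the commented-out OpenAIEmbeddings::truncate below.
use tiktoken_rs::cl100k_base;

fn truncate_span(span: &str, max_tokens: usize) -> String {
    let bpe = cl100k_base().unwrap();
    let mut tokens = bpe.encode_with_special_tokens(span);
    if tokens.len() > max_tokens {
        tokens.truncate(max_tokens);
        // If the clipped token sequence fails to decode, keep the original span.
        if let Ok(clipped) = bpe.decode(tokens) {
            return clipped;
        }
    }
    span.to_string()
}

fn main() {
    println!("{}", truncate_span("fn main() { println!(\"hello\"); }", 8192));
}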
diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs
index 86d8494ab4..72b30d9424 100644
--- a/crates/vector_store/src/embedding.rs
+++ b/crates/vector_store/src/embedding.rs
@@ -5,8 +5,8 @@ use gpui::serde_json;
 use isahc::prelude::Configurable;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
-use std::env;
 use std::sync::Arc;
+use std::{env, time::Instant};
 use util::http::{HttpClient, Request};

 lazy_static! {
@@ -60,9 +60,34 @@ impl EmbeddingProvider for DummyEmbeddings {
     }
 }

+// impl OpenAIEmbeddings {
+//     async fn truncate(span: &str) -> String {
+//         let bpe = cl100k_base().unwrap();
+//         let mut tokens = bpe.encode_with_special_tokens(span);
+//         if tokens.len() > 8192 {
+//             tokens.truncate(8192);
+//             let result = bpe.decode(tokens);
+//             if result.is_ok() {
+//                 return result.unwrap();
+//             }
+//         }

+//         return span.to_string();
+//     }
+// }
+
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        // Truncate spans to 8192 tokens if needed
+        // let t0 = Instant::now();
+        // let mut truncated_spans = vec![];
+        // for span in spans {
+        //     truncated_spans.push(Self::truncate(span));
+        // }
+        // let spans = futures::future::join_all(truncated_spans).await;
+        // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());
+
         let api_key = OPENAI_API_KEY
             .as_ref()
             .ok_or_else(|| anyhow!("no api key"))?;
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index e072793e25..a63674bc34 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -34,7 +34,7 @@ use util::{
 use workspace::{Workspace, WorkspaceCreated};

 const REINDEXING_DELAY: u64 = 30;
-const EMBEDDINGS_BATCH_SIZE: usize = 25;
+const EMBEDDINGS_BATCH_SIZE: usize = 150;

 #[derive(Debug, Clone)]
 pub struct Document {
@@ -74,6 +74,7 @@ pub fn init(
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
         move |event, cx| {
+            let t0 = Instant::now();
             let workspace = &event.0;
             if let Some(workspace) = workspace.upgrade(cx) {
                 let project = workspace.read(cx).project().clone();
@@ -124,10 +125,14 @@ pub struct VectorStore {
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
     db_update_tx: channel::Sender<DbWrite>,
-    paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
+    // embed_batch_tx: channel::Sender<Vec<(i64, IndexedFile, Vec<String>)>>,
+    batch_files_tx: channel::Sender<(i64, IndexedFile, Vec<String>)>,
+    parsing_files_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
+    parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc<Language>, SystemTime)>,
     _db_update_task: Task<()>,
-    _paths_update_task: Task<()>,
-    _embeddings_update_task: Task<()>,
+    _embed_batch_task: Vec<Task<()>>,
+    _batch_files_task: Task<()>,
+    _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -188,12 +193,8 @@ impl VectorStore {
         Ok(cx.add_model(|cx| {
             // paths_tx -> embeddings_tx -> db_update_tx

+            // db_update_tx/rx: Updating Database
             let (db_update_tx, db_update_rx) = channel::unbounded();
-            let (paths_tx, paths_rx) =
-                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
-            let (embeddings_tx, embeddings_rx) =
-                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
-
             let _db_update_task = cx.background().spawn(async move {
                 while let Ok(job) = db_update_rx.recv().await {
                     match job {
@@ -201,6 +202,7 @@
                             worktree_id,
                             indexed_file,
                         } => {
+                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                             db.insert_file(worktree_id, indexed_file).log_err();
                         }
                         DbWrite::Delete { worktree_id, path } => {
@@ -214,132 +216,137 @@
                 }
             });

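The hunk that follows is the core of the commit: the single embed_batch helper becomes a pool of worker tasks, one per CPU, all pulling from one shared channel; the same pattern is reused for parsing further down. A minimal, self-contained sketch of that fan-out (worker count, job type, and names are illustrative, not from the patch):

// Sketch only: multi-consumer channel fan-out, the pattern behind the
// _embed_batch_task and _parsing_files_tasks pools added below.
use std::thread::available_parallelism;

fn main() {
    smol::block_on(async {
        let (tx, rx) = smol::channel::unbounded::<String>();
        let workers = available_parallelism().map(|n| n.get()).unwrap_or(1);

        let mut tasks = Vec::new();
        for id in 0..workers {
            // Each worker owns a clone of the receiver; recv() completes on
            // whichever worker is idle, so jobs spread across all cores.
            let rx = rx.clone();
            tasks.push(smol::spawn(async move {
                while let Ok(job) = rx.recv().await {
                    println!("worker {id} processing {job}");
                }
            }));
        }

        tx.send("file_a.rs".to_string()).await.unwrap();
        tx.send("file_b.rs".to_string()).await.unwrap();
        drop(tx); // closing the channel ends each worker's recv() loop

        for task in tasks {
            task.await;
        }
    });
}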
-            async fn embed_batch(
-                embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
-                embedding_provider: &Arc<dyn EmbeddingProvider>,
-                db_update_tx: channel::Sender<DbWrite>,
-            ) -> Result<()> {
-                let mut embeddings_queue = embeddings_queue.clone();
+            // embed_tx/rx: Embed Batch and Send to Database
+            let (embed_batch_tx, embed_batch_rx) =
+                channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
+            let mut _embed_batch_task = Vec::new();
+            for _ in 0..cx.background().num_cpus() {
+                let db_update_tx = db_update_tx.clone();
+                let embed_batch_rx = embed_batch_rx.clone();
+                let embedding_provider = embedding_provider.clone();
+                _embed_batch_task.push(cx.background().spawn(async move {
+                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                        log::info!("Embedding Batch!");

-                let mut document_spans = vec![];
-                for (_, _, document_span) in embeddings_queue.clone().into_iter() {
-                    document_spans.extend(document_span);
-                }
+                        // Construct Batch
+                        let mut embeddings_queue = embeddings_queue.clone();
+                        let mut document_spans = vec![];
+                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
+                            document_spans.extend(document_span);
+                        }

-                let mut embeddings = embedding_provider
-                    .embed_batch(document_spans.iter().map(|x| &**x).collect())
-                    .await?;
+                        if let Some(mut embeddings) = embedding_provider
+                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
+                            .await
+                            .log_err()
+                        {
+                            let mut i = 0;
+                            let mut j = 0;
+                            while let Some(embedding) = embeddings.pop() {
+                                while embeddings_queue[i].1.documents.len() == j {
+                                    i += 1;
+                                    j = 0;
+                                }

-                // This assumes the embeddings are returned in order
-                let t0 = Instant::now();
-                let mut i = 0;
-                let mut j = 0;
-                while let Some(embedding) = embeddings.pop() {
-                    // This has to accommodate multiple indexed_files in a row without documents
-                    while embeddings_queue[i].1.documents.len() == j {
-                        i += 1;
-                        j = 0;
+                                embeddings_queue[i].1.documents[j].embedding = embedding;
+                                j += 1;
+                            }
+
+                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
+                                for document in indexed_file.documents.iter() {
+                                    // TODO: Update this so it doesn't panic
+                                    assert!(
+                                        document.embedding.len() > 0,
+                                        "Document Embedding Not Complete"
+                                    );
+                                }
+
+                                db_update_tx
+                                    .send(DbWrite::InsertFile {
+                                        worktree_id,
+                                        indexed_file,
+                                    })
+                                    .await
+                                    .unwrap();
+                            }
+                        }
                     }
-
-                    embeddings_queue[i].1.documents[j].embedding = embedding;
-                    j += 1;
-                }
-
-                for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
-                    // TODO: Update this so it doesn't panic
-                    for document in indexed_file.documents.iter() {
-                        assert!(
-                            document.embedding.len() > 0,
-                            "Document Embedding not Complete"
-                        );
-                    }
-
-                    db_update_tx
-                        .send(DbWrite::InsertFile {
-                            worktree_id,
-                            indexed_file,
-                        })
-                        .await
-                        .unwrap();
-                }
-
-                anyhow::Ok(())
+                }))
             }
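The reassignment loop above leans on the provider returning exactly one embedding per span, in request order: i walks the queued files and j walks each file's documents, skipping files that contributed no spans. Note the patch pops embeddings from the back of the result vector, so as written it fills documents in reverse; the sketch below walks front to back to show the intended invariant (types are illustrative, not from the patch):

// Sketch only: re-attaching a flat batch of embeddings to per-file documents.
#[derive(Debug, Default)]
struct Doc {
    embedding: Vec<f32>,
}

fn assign_embeddings(files: &mut [Vec<Doc>], embeddings: Vec<Vec<f32>>) {
    let (mut i, mut j) = (0, 0);
    for embedding in embeddings {
        // Skip over files that contributed no documents to this batch.
        while files[i].len() == j {
            i += 1;
            j = 0;
        }
        files[i][j].embedding = embedding;
        j += 1;
    }
}

fn main() {
    let mut files = vec![vec![Doc::default()], vec![], vec![Doc::default()]];
    assign_embeddings(&mut files, vec![vec![0.1], vec![0.2]]);
    assert_eq!(files[2][0].embedding, vec![0.2]);
    println!("{files:?}");
}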
-            let embedding_provider_clone = embedding_provider.clone();
-
-            let db_update_tx_clone = db_update_tx.clone();
-            let _embeddings_update_task = cx.background().spawn(async move {
+            // batch_tx/rx: Batch Files to Send for Embeddings
+            let (batch_files_tx, batch_files_rx) =
+                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
+            let _batch_files_task = cx.background().spawn(async move {
                 let mut queue_len = 0;
                 let mut embeddings_queue = vec![];
-                let mut request_count = 0;
                 while let Ok((worktree_id, indexed_file, document_spans)) =
-                    embeddings_rx.recv().await
+                    batch_files_rx.recv().await
                 {
+                    log::info!("Batching File: {:?}", &indexed_file.path);
                     queue_len += &document_spans.len();
                     embeddings_queue.push((worktree_id, indexed_file, document_spans));
-
                     if queue_len >= EMBEDDINGS_BATCH_SIZE {
-                        let _ = embed_batch(
-                            embeddings_queue,
-                            &embedding_provider_clone,
-                            db_update_tx_clone.clone(),
-                        )
-                        .await;
-
+                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                         embeddings_queue = vec![];
                         queue_len = 0;
-
-                        request_count += 1;
                     }
                 }
-
                 if queue_len > 0 {
-                    let _ = embed_batch(
-                        embeddings_queue,
-                        &embedding_provider_clone,
-                        db_update_tx_clone.clone(),
-                    )
-                    .await;
-                    request_count += 1;
+                    embed_batch_tx.try_send(embeddings_queue).unwrap();
                 }
             });

-            let fs_clone = fs.clone();
+            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
+            let (parsing_files_tx, parsing_files_rx) =
+                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();

-            let _paths_update_task = cx.background().spawn(async move {
-                let mut parser = Parser::new();
-                let mut cursor = QueryCursor::new();
-                while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
-                    if let Some((indexed_file, document_spans)) = Self::index_file(
-                        &mut cursor,
-                        &mut parser,
-                        &fs_clone,
-                        language,
-                        file_path.clone(),
-                        mtime,
-                    )
-                    .await
-                    .log_err()
+            let mut _parsing_files_tasks = Vec::new();
+            for _ in 0..cx.background().num_cpus() {
+                let fs = fs.clone();
+                let parsing_files_rx = parsing_files_rx.clone();
+                let batch_files_tx = batch_files_tx.clone();
+                _parsing_files_tasks.push(cx.background().spawn(async move {
+                    let mut parser = Parser::new();
+                    let mut cursor = QueryCursor::new();
+                    while let Ok((worktree_id, file_path, language, mtime)) =
+                        parsing_files_rx.recv().await
                     {
-                        embeddings_tx
-                            .try_send((worktree_id, indexed_file, document_spans))
-                            .unwrap();
+                        log::info!("Parsing File: {:?}", &file_path);
+                        if let Some((indexed_file, document_spans)) = Self::index_file(
+                            &mut cursor,
+                            &mut parser,
+                            &fs,
+                            language,
+                            file_path.clone(),
+                            mtime,
+                        )
+                        .await
+                        .log_err()
+                        {
+                            batch_files_tx
+                                .try_send((worktree_id, indexed_file, document_spans))
+                                .unwrap();
+                        }
                     }
-                }
-            });
+                }));
+            }

             Self {
                 fs,
                 database_url,
-                db_update_tx,
-                paths_tx,
                 embedding_provider,
                 language_registry,
-                projects: HashMap::new(),
+                db_update_tx,
+                // embed_batch_tx,
+                batch_files_tx,
+                parsing_files_tx,
+                parsing_files_rx,
                 _db_update_task,
-                _paths_update_task,
-                _embeddings_update_task,
+                _embed_batch_task,
+                _batch_files_task,
+                _parsing_files_tasks,
+                projects: HashMap::new(),
             }
         }))
     }
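For context on the constant change earlier (EMBEDDINGS_BATCH_SIZE: 25 → 150): the batch_files task above flushes its queue to the embedding workers once the accumulated span count reaches the batch size, with a final flush when the channel closes. The rule in isolation (paths and span counts are made up):

// Sketch only: the span-count batching rule used by _batch_files_task.
const EMBEDDINGS_BATCH_SIZE: usize = 150;

fn main() {
    let jobs: Vec<(&str, usize)> = vec![("a.rs", 80), ("b.rs", 90), ("c.rs", 20)];
    let mut queue: Vec<&str> = Vec::new();
    let mut queue_len = 0;
    for (path, span_count) in jobs {
        queue.push(path);
        queue_len += span_count;
        if queue_len >= EMBEDDINGS_BATCH_SIZE {
            // Hand the whole queue to the embedding workers at once.
            println!("flushing batch of {queue_len} spans: {queue:?}");
            queue.clear();
            queue_len = 0;
        }
    }
    if !queue.is_empty() {
        // Final flush once the incoming channel is exhausted.
        println!("flushing final batch of {queue_len} spans: {queue:?}");
    }
}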
@@ -441,12 +448,16 @@ impl VectorStore {
         let language_registry = self.language_registry.clone();
         let database_url = self.database_url.clone();
         let db_update_tx = self.db_update_tx.clone();
-        let paths_tx = self.paths_tx.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+        let parsing_files_rx = self.parsing_files_rx.clone();
+        let batch_files_tx = self.batch_files_tx.clone();

         cx.spawn(|this, mut cx| async move {
+            let t0 = Instant::now();
             futures::future::join_all(worktree_scans_complete).await;

             let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
+            log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());

             if let Some(db_directory) = database_url.parent() {
                 fs.create_dir(db_directory).await.log_err();
@@ -485,8 +496,9 @@
                 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                 let db_update_tx = db_update_tx.clone();
                 let language_registry = language_registry.clone();
-                let paths_tx = paths_tx.clone();
+                let parsing_files_tx = parsing_files_tx.clone();
                 async move {
+                    let t0 = Instant::now();
                     for worktree in worktrees.into_iter() {
                         let mut file_mtimes =
                             worktree_file_times.remove(&worktree.id()).unwrap();
@@ -513,7 +525,7 @@
                         });

                         if !already_stored {
-                            paths_tx
+                            parsing_files_tx
                                 .try_send((
                                     db_ids_by_worktree_id[&worktree.id()],
                                     path_buf,
                                     language,
                                     mtime,
                                 ))
                                 .unwrap();
@@ -533,10 +545,45 @@
                                 .unwrap();
                         }
                     }
+                    log::info!(
+                        "Parsing Worktree Completed in {:?}",
+                        t0.elapsed().as_millis()
+                    );
                 }
             })
             .detach();

+            // cx.background()
+            //     .scoped(|scope| {
+            //         for _ in 0..cx.background().num_cpus() {
+            //             scope.spawn(async {
+            //                 let mut parser = Parser::new();
+            //                 let mut cursor = QueryCursor::new();
+            //                 while let Ok((worktree_id, file_path, language, mtime)) =
+            //                     parsing_files_rx.recv().await
+            //                 {
+            //                     log::info!("Parsing File: {:?}", &file_path);
+            //                     if let Some((indexed_file, document_spans)) = Self::index_file(
+            //                         &mut cursor,
+            //                         &mut parser,
+            //                         &fs,
+            //                         language,
+            //                         file_path.clone(),
+            //                         mtime,
+            //                     )
+            //                     .await
+            //                     .log_err()
+            //                     {
+            //                         batch_files_tx
+            //                             .try_send((worktree_id, indexed_file, document_spans))
+            //                             .unwrap();
+            //                     }
+            //                 }
+            //             });
+            //         }
+            //     })
+            //     .await;
+
             this.update(&mut cx, |this, cx| {
+                // The below manages re-indexing on save: every time a file is saved this
+                // code runs, and a changed file is re-queued for indexing only if its new
+                // mtime differs from the stored one and the stored mtime is at least
+                // REINDEXING_DELAY seconds old.
@@ -545,90 +592,90 @@
                 if let Some(project_state) = this.projects.get(&project.downgrade()) {
                     let worktree_db_ids = project_state.worktree_db_ids.clone();

-                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
-                    {
-                        // Iterate through changes
-                        let language_registry = this.language_registry.clone();
+                    // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
+                    // {
+                    //     // Iterate through changes
+                    //     let language_registry = this.language_registry.clone();

-                        let db =
-                            VectorDatabase::new(this.database_url.to_string_lossy().into());
-                        if db.is_err() {
-                            return;
-                        }
-                        let db = db.unwrap();
+                    //     let db =
+                    //         VectorDatabase::new(this.database_url.to_string_lossy().into());
+                    //     if db.is_err() {
+                    //         return;
+                    //     }
+                    //     let db = db.unwrap();

-                        let worktree_db_id: Option<i64> = {
-                            let mut found_db_id = None;
-                            for (w_id, db_id) in worktree_db_ids.into_iter() {
-                                if &w_id == worktree_id {
-                                    found_db_id = Some(db_id);
-                                }
-                            }
-                            found_db_id
-                        };
+                    //     let worktree_db_id: Option<i64> = {
+                    //         let mut found_db_id = None;
+                    //         for (w_id, db_id) in worktree_db_ids.into_iter() {
+                    //             if &w_id == worktree_id {
+                    //                 found_db_id = Some(db_id);
+                    //             }
+                    //         }
+                    //         found_db_id
+                    //     };

-                        if worktree_db_id.is_none() {
-                            return;
-                        }
-                        let worktree_db_id = worktree_db_id.unwrap();
+                    //     if worktree_db_id.is_none() {
+                    //         return;
+                    //     }
+                    //     let worktree_db_id = worktree_db_id.unwrap();

-                        let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                        if file_mtimes.is_err() {
-                            return;
-                        }
+                    //     let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                    //     if file_mtimes.is_err() {
+                    //         return;
+                    //     }

-                        let file_mtimes = file_mtimes.unwrap();
-                        let paths_tx = this.paths_tx.clone();
+                    //     let file_mtimes = file_mtimes.unwrap();
+                    //     let paths_tx = this.paths_tx.clone();

-                        smol::block_on(async move {
-                            for change in changes.into_iter() {
-                                let change_path = change.0.clone();
-                                log::info!("Change: {:?}", &change_path);
-                                if let Ok(language) = language_registry
-                                    .language_for_file(&change_path.to_path_buf(), None)
-                                    .await
-                                {
-                                    if language
-                                        .grammar()
-                                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                                        .is_none()
-                                    {
-                                        continue;
-                                    }
+                    //     smol::block_on(async move {
+                    //         for change in changes.into_iter() {
+                    //             let change_path = change.0.clone();
+                    //             log::info!("Change: {:?}", &change_path);
+                    //             if let Ok(language) = language_registry
+                    //                 .language_for_file(&change_path.to_path_buf(), None)
+                    //                 .await
+                    //             {
+                    //                 if language
+                    //                     .grammar()
+                    //                     .and_then(|grammar| grammar.embedding_config.as_ref())
+                    //                     .is_none()
+                    //                 {
+                    //                     continue;
+                    //                 }

-                                    // TODO: Make this a bit more defensive
-                                    let modified_time =
-                                        change_path.metadata().unwrap().modified().unwrap();
-                                    let existing_time =
-                                        file_mtimes.get(&change_path.to_path_buf());
-                                    let already_stored =
-                                        existing_time.map_or(false, |existing_time| {
-                                            if &modified_time != existing_time
-                                                && existing_time.elapsed().unwrap().as_secs()
-                                                    > REINDEXING_DELAY
-                                            {
-                                                false
-                                            } else {
-                                                true
-                                            }
-                                        });
+                    //                 // TODO: Make this a bit more defensive
+                    //                 let modified_time =
+                    //                     change_path.metadata().unwrap().modified().unwrap();
+                    //                 let existing_time =
+                    //                     file_mtimes.get(&change_path.to_path_buf());
+                    //                 let already_stored =
+                    //                     existing_time.map_or(false, |existing_time| {
+                    //                         if &modified_time != existing_time
+                    //                             && existing_time.elapsed().unwrap().as_secs()
+                    //                                 > REINDEXING_DELAY
+                    //                         {
+                    //                             false
+                    //                         } else {
+                    //                             true
+                    //                         }
+                    //                     });

-                                    if !already_stored {
-                                        log::info!("Need to reindex: {:?}", &change_path);
-                                        paths_tx
-                                            .try_send((
-                                                worktree_db_id,
-                                                change_path.to_path_buf(),
-                                                language,
-                                                modified_time,
-                                            ))
-                                            .unwrap();
-                                    }
-                                }
-                            }
-                        })
-                    }
+                    //                 if !already_stored {
+                    //                     log::info!("Need to reindex: {:?}", &change_path);
+                    //                     paths_tx
+                    //                         .try_send((
+                    //                             worktree_db_id,
+                    //                             change_path.to_path_buf(),
+                    //                             language,
+                    //                             modified_time,
+                    //                         ))
+                    //                         .unwrap();
+                    //                 }
+                    //             }
+                    //         }
+                    //     })
+                    // }
                 }
             });
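Closing note: the reindex-on-save handler stays commented out in this patch, but its debounce rule is worth keeping in mind for when it returns. A changed file is re-queued only if its new mtime differs from the stored one and the stored mtime is at least REINDEXING_DELAY seconds old. A sketch of that check, using a hypothetical helper not present in the patch:

// Sketch only: the mtime debounce from the commented-out handler above.
use std::time::{Duration, SystemTime};

const REINDEXING_DELAY: u64 = 30;

fn needs_reindex(modified: SystemTime, stored: Option<SystemTime>) -> bool {
    match stored {
        None => true, // never indexed before
        Some(stored) => {
            // Only re-queue once the stored mtime is old enough, so rapid
            // successive saves don't trigger repeated indexing passes.
            let old_enough = stored
                .elapsed()
                .map_or(false, |age| age.as_secs() > REINDEXING_DELAY);
            modified != stored && old_enough
        }
    }
}

fn main() {
    let now = SystemTime::now();
    let minute_ago = now - Duration::from_secs(60);
    assert!(needs_reindex(now, Some(minute_ago)));
    assert!(!needs_reindex(now, Some(now)));
    println!("debounce checks pass");
}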