enabled batching for embedding calls

This commit is contained in:
KCaverly 2023-07-05 10:02:42 -04:00
parent b6520a8f1d
commit eff0ee3b60

View file

@ -22,7 +22,7 @@ use std::{
collections::HashMap, collections::HashMap,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::SystemTime, time::{Instant, SystemTime},
}; };
use tree_sitter::{Parser, QueryCursor}; use tree_sitter::{Parser, QueryCursor};
use util::{ use util::{
@ -34,8 +34,9 @@ use util::{
use workspace::{Workspace, WorkspaceCreated}; use workspace::{Workspace, WorkspaceCreated};
const REINDEXING_DELAY: u64 = 30; const REINDEXING_DELAY: u64 = 30;
const EMBEDDINGS_BATCH_SIZE: usize = 25;
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Document { pub struct Document {
pub offset: usize, pub offset: usize,
pub name: String, pub name: String,
@ -110,7 +111,7 @@ pub fn init(
.detach(); .detach();
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct IndexedFile { pub struct IndexedFile {
path: PathBuf, path: PathBuf,
mtime: SystemTime, mtime: SystemTime,
@ -126,6 +127,7 @@ pub struct VectorStore {
paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>, paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
_db_update_task: Task<()>, _db_update_task: Task<()>,
_paths_update_task: Task<()>, _paths_update_task: Task<()>,
_embeddings_update_task: Task<()>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>, projects: HashMap<WeakModelHandle<Project>, ProjectState>,
} }
@ -184,7 +186,14 @@ impl VectorStore {
.await?; .await?;
Ok(cx.add_model(|cx| { Ok(cx.add_model(|cx| {
// paths_tx -> embeddings_tx -> db_update_tx
let (db_update_tx, db_update_rx) = channel::unbounded(); let (db_update_tx, db_update_rx) = channel::unbounded();
let (paths_tx, paths_rx) =
channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
let (embeddings_tx, embeddings_rx) =
channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
let _db_update_task = cx.background().spawn(async move { let _db_update_task = cx.background().spawn(async move {
while let Ok(job) = db_update_rx.recv().await { while let Ok(job) = db_update_rx.recv().await {
match job { match job {
@ -192,11 +201,9 @@ impl VectorStore {
worktree_id, worktree_id,
indexed_file, indexed_file,
} => { } => {
log::info!("Inserting File: {:?}", &indexed_file.path);
db.insert_file(worktree_id, indexed_file).log_err(); db.insert_file(worktree_id, indexed_file).log_err();
} }
DbWrite::Delete { worktree_id, path } => { DbWrite::Delete { worktree_id, path } => {
log::info!("Deleting File: {:?}", &path);
db.delete_file(worktree_id, path).log_err(); db.delete_file(worktree_id, path).log_err();
} }
DbWrite::FindOrCreateWorktree { path, sender } => { DbWrite::FindOrCreateWorktree { path, sender } => {
@ -207,35 +214,116 @@ impl VectorStore {
} }
}); });
let (paths_tx, paths_rx) = async fn embed_batch(
channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>(); embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
embedding_provider: &Arc<dyn EmbeddingProvider>,
db_update_tx: channel::Sender<DbWrite>,
) -> Result<()> {
let mut embeddings_queue = embeddings_queue.clone();
let mut document_spans = vec![];
for (_, _, document_span) in embeddings_queue.clone().into_iter() {
document_spans.extend(document_span);
}
let mut embeddings = embedding_provider
.embed_batch(document_spans.iter().map(|x| &**x).collect())
.await?;
// This assumes the embeddings are returned in order
let t0 = Instant::now();
let mut i = 0;
let mut j = 0;
while let Some(embedding) = embeddings.pop() {
// This has to accomodate for multiple indexed_files in a row without documents
while embeddings_queue[i].1.documents.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1.documents[j].embedding = embedding;
j += 1;
}
for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
// TODO: Update this so it doesnt panic
for document in indexed_file.documents.iter() {
assert!(
document.embedding.len() > 0,
"Document Embedding not Complete"
);
}
db_update_tx
.send(DbWrite::InsertFile {
worktree_id,
indexed_file,
})
.await
.unwrap();
}
anyhow::Ok(())
}
// Background task that sits between file parsing and database writes: it
// accumulates parsed files until at least EMBEDDINGS_BATCH_SIZE document
// spans are queued, then embeds the whole queue with one `embed_batch` call.
let embedding_provider_clone = embedding_provider.clone();
let db_update_tx_clone = db_update_tx.clone();
let _embeddings_update_task = cx.background().spawn(async move {
    let mut queued_span_count = 0;
    let mut embeddings_queue = vec![];

    while let Ok((worktree_id, indexed_file, document_spans)) = embeddings_rx.recv().await {
        queued_span_count += document_spans.len();
        embeddings_queue.push((worktree_id, indexed_file, document_spans));

        if queued_span_count >= EMBEDDINGS_BATCH_SIZE {
            // Hand the accumulated queue over and leave an empty one behind.
            // Failures are logged rather than silently dropped so a broken
            // provider or closed database channel is visible.
            embed_batch(
                std::mem::take(&mut embeddings_queue),
                &embedding_provider_clone,
                db_update_tx_clone.clone(),
            )
            .await
            .log_err();
            queued_span_count = 0;
        }
    }

    // The channel has closed: flush whatever is still waiting.
    // NOTE(review): a partial batch is only flushed here, i.e. once every
    // sender is gone — confirm that trailing files are not left unembedded
    // while the store is still alive.
    if queued_span_count > 0 {
        embed_batch(
            embeddings_queue,
            &embedding_provider_clone,
            db_update_tx_clone.clone(),
        )
        .await
        .log_err();
    }
});
let fs_clone = fs.clone(); let fs_clone = fs.clone();
let db_update_tx_clone = db_update_tx.clone();
let embedding_provider_clone = embedding_provider.clone();
let _paths_update_task = cx.background().spawn(async move { let _paths_update_task = cx.background().spawn(async move {
let mut parser = Parser::new(); let mut parser = Parser::new();
let mut cursor = QueryCursor::new(); let mut cursor = QueryCursor::new();
while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await { while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
log::info!("Parsing File: {:?}", &file_path); if let Some((indexed_file, document_spans)) = Self::index_file(
if let Some(indexed_file) = Self::index_file(
&mut cursor, &mut cursor,
&mut parser, &mut parser,
embedding_provider_clone.as_ref(),
&fs_clone, &fs_clone,
language, language,
file_path, file_path.clone(),
mtime, mtime,
) )
.await .await
.log_err() .log_err()
{ {
db_update_tx_clone embeddings_tx
.try_send(DbWrite::InsertFile { .try_send((worktree_id, indexed_file, document_spans))
worktree_id,
indexed_file,
})
.unwrap(); .unwrap();
} }
} }
@ -251,6 +339,7 @@ impl VectorStore {
projects: HashMap::new(), projects: HashMap::new(),
_db_update_task, _db_update_task,
_paths_update_task, _paths_update_task,
_embeddings_update_task,
} }
})) }))
} }
@ -258,12 +347,11 @@ impl VectorStore {
async fn index_file( async fn index_file(
cursor: &mut QueryCursor, cursor: &mut QueryCursor,
parser: &mut Parser, parser: &mut Parser,
embedding_provider: &dyn EmbeddingProvider,
fs: &Arc<dyn Fs>, fs: &Arc<dyn Fs>,
language: Arc<Language>, language: Arc<Language>,
file_path: PathBuf, file_path: PathBuf,
mtime: SystemTime, mtime: SystemTime,
) -> Result<IndexedFile> { ) -> Result<(IndexedFile, Vec<String>)> {
let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
let embedding_config = grammar let embedding_config = grammar
.embedding_config .embedding_config
@ -298,7 +386,7 @@ impl VectorStore {
if let Some((item, name)) = if let Some((item, name)) =
content.get(item_range.clone()).zip(content.get(name_range)) content.get(item_range.clone()).zip(content.get(name_range))
{ {
context_spans.push(item); context_spans.push(item.to_string());
documents.push(Document { documents.push(Document {
name: name.to_string(), name: name.to_string(),
offset: item_range.start, offset: item_range.start,
@ -308,18 +396,14 @@ impl VectorStore {
} }
} }
if !documents.is_empty() { return Ok((
let embeddings = embedding_provider.embed_batch(context_spans).await?; IndexedFile {
for (document, embedding) in documents.iter_mut().zip(embeddings) { path: file_path,
document.embedding = embedding; mtime,
} documents,
} },
context_spans,
return Ok(IndexedFile { ));
path: file_path,
mtime,
documents,
});
} }
fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> { fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
@ -454,6 +538,9 @@ impl VectorStore {
.detach(); .detach();
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
// The code below handles re-indexing on save.
// Each time a file is saved, this code runs; for every file that changed, if the current time exceeds the
// previously embedded time by more than REINDEXING_DELAY, the file is sent off to be indexed.
let _subscription = cx.subscribe(&project, |this, project, event, _cx| { let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
if let Some(project_state) = this.projects.get(&project.downgrade()) { if let Some(project_state) = this.projects.get(&project.downgrade()) {
let worktree_db_ids = project_state.worktree_db_ids.clone(); let worktree_db_ids = project_state.worktree_db_ids.clone();
@ -554,8 +641,6 @@ impl VectorStore {
); );
}); });
log::info!("Semantic Indexing Complete!");
anyhow::Ok(()) anyhow::Ok(())
}) })
} }
@ -591,8 +676,6 @@ impl VectorStore {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
log::info!("Searching for: {:?}", phrase);
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
let database_url = self.database_url.clone(); let database_url = self.database_url.clone();
cx.spawn(|this, cx| async move { cx.spawn(|this, cx| async move {