enabled batching for embedding calls
parent b6520a8f1d
commit eff0ee3b60

1 changed file with 120 additions and 37 deletions
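The change replaces per-file embedding inside index_file with a three-stage pipeline (paths_tx -> embeddings_tx -> db_update_tx): the parsing task now forwards each file's document spans to a new embeddings task, which accumulates spans until EMBEDDINGS_BATCH_SIZE (25) is reached and embeds the whole batch in a single provider call before handing the files to the database writer. Below is a minimal sketch of that batching consumer, using synchronous std::sync::mpsc and a stub embed_batch in place of the commit's async channels, gpui background tasks, and EmbeddingProvider trait; names besides embed_batch and EMBEDDINGS_BATCH_SIZE are illustrative.

use std::sync::mpsc;

const EMBEDDINGS_BATCH_SIZE: usize = 25;

// Stub standing in for EmbeddingProvider::embed_batch: one request per call,
// so batching spans across files amortizes per-request overhead.
fn embed_batch(spans: &[String]) -> Vec<Vec<f32>> {
    spans.iter().map(|_| vec![0.0; 3]).collect() // placeholder embeddings
}

fn main() {
    // (worktree_id, document spans) -- the shape sent over embeddings_tx.
    let (tx, rx) = mpsc::channel::<(i64, Vec<String>)>();

    // Producer: the parsing task sends each file's spans as it finishes.
    for file in 0..10i64 {
        tx.send((file, vec![format!("span for file {file}")])).unwrap();
    }
    drop(tx); // close the channel so the consumer loop terminates

    // Consumer: accumulate until the batch threshold, then embed the whole
    // batch with one call, as _embeddings_update_task does.
    let mut queue: Vec<(i64, Vec<String>)> = vec![];
    let mut queue_len = 0;
    while let Ok((worktree_id, spans)) = rx.recv() {
        queue_len += spans.len();
        queue.push((worktree_id, spans));
        if queue_len >= EMBEDDINGS_BATCH_SIZE {
            let flat: Vec<String> = queue.drain(..).flat_map(|(_, s)| s).collect();
            let _embeddings = embed_batch(&flat);
            queue_len = 0;
        }
    }
    // Flush the final partial batch, mirroring the diff's `if queue_len > 0` block.
    if queue_len > 0 {
        let flat: Vec<String> = queue.drain(..).flat_map(|(_, s)| s).collect();
        let _embeddings = embed_batch(&flat);
    }
}

Batching amortizes one network round-trip over up to 25 files' spans instead of paying it per file; the post-loop flush covers whatever is queued when the channel closes.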
@@ -22,7 +22,7 @@ use std::{
     collections::HashMap,
     path::{Path, PathBuf},
     sync::Arc,
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use tree_sitter::{Parser, QueryCursor};
 use util::{
@@ -34,8 +34,9 @@ use util::{
 use workspace::{Workspace, WorkspaceCreated};
 
 const REINDEXING_DELAY: u64 = 30;
+const EMBEDDINGS_BATCH_SIZE: usize = 25;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Document {
     pub offset: usize,
     pub name: String,
@@ -110,7 +111,7 @@ pub fn init(
         .detach();
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct IndexedFile {
     path: PathBuf,
     mtime: SystemTime,
@@ -126,6 +127,7 @@ pub struct VectorStore {
     paths_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
     _db_update_task: Task<()>,
     _paths_update_task: Task<()>,
+    _embeddings_update_task: Task<()>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
@@ -184,7 +186,14 @@ impl VectorStore {
         .await?;
 
         Ok(cx.add_model(|cx| {
+            // paths_tx -> embeddings_tx -> db_update_tx
+
             let (db_update_tx, db_update_rx) = channel::unbounded();
+            let (paths_tx, paths_rx) =
+                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
+            let (embeddings_tx, embeddings_rx) =
+                channel::unbounded::<(i64, IndexedFile, Vec<String>)>();
+
             let _db_update_task = cx.background().spawn(async move {
                 while let Ok(job) = db_update_rx.recv().await {
                     match job {
@@ -192,11 +201,9 @@ impl VectorStore {
                             worktree_id,
                             indexed_file,
                         } => {
-                            log::info!("Inserting File: {:?}", &indexed_file.path);
                             db.insert_file(worktree_id, indexed_file).log_err();
                         }
                         DbWrite::Delete { worktree_id, path } => {
-                            log::info!("Deleting File: {:?}", &path);
                             db.delete_file(worktree_id, path).log_err();
                         }
                         DbWrite::FindOrCreateWorktree { path, sender } => {
@@ -207,35 +214,116 @@ impl VectorStore {
                     }
                 }
             });
 
-            let (paths_tx, paths_rx) =
-                channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
+            async fn embed_batch(
+                embeddings_queue: Vec<(i64, IndexedFile, Vec<String>)>,
+                embedding_provider: &Arc<dyn EmbeddingProvider>,
+                db_update_tx: channel::Sender<DbWrite>,
+            ) -> Result<()> {
+                let mut embeddings_queue = embeddings_queue.clone();
+
+                let mut document_spans = vec![];
+                for (_, _, document_span) in embeddings_queue.clone().into_iter() {
+                    document_spans.extend(document_span);
+                }
+
+                let mut embeddings = embedding_provider
+                    .embed_batch(document_spans.iter().map(|x| &**x).collect())
+                    .await?;
+
+                // This assumes the embeddings are returned in order
+                let t0 = Instant::now();
+                let mut i = 0;
+                let mut j = 0;
+                while let Some(embedding) = embeddings.pop() {
+                    // This has to accommodate multiple indexed_files in a row without documents
+                    while embeddings_queue[i].1.documents.len() == j {
+                        i += 1;
+                        j = 0;
+                    }
+
+                    embeddings_queue[i].1.documents[j].embedding = embedding;
+                    j += 1;
+                }
+
+                for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
+                    // TODO: Update this so it doesn't panic
+                    for document in indexed_file.documents.iter() {
+                        assert!(
+                            document.embedding.len() > 0,
+                            "Document Embedding not Complete"
+                        );
+                    }
+
+                    db_update_tx
+                        .send(DbWrite::InsertFile {
+                            worktree_id,
+                            indexed_file,
+                        })
+                        .await
+                        .unwrap();
+                }
+
+                anyhow::Ok(())
+            }
+
+            let embedding_provider_clone = embedding_provider.clone();
+
+            let db_update_tx_clone = db_update_tx.clone();
+            let _embeddings_update_task = cx.background().spawn(async move {
+                let mut queue_len = 0;
+                let mut embeddings_queue = vec![];
+                let mut request_count = 0;
+                while let Ok((worktree_id, indexed_file, document_spans)) =
+                    embeddings_rx.recv().await
+                {
+                    queue_len += &document_spans.len();
+                    embeddings_queue.push((worktree_id, indexed_file, document_spans));
+
+                    if queue_len >= EMBEDDINGS_BATCH_SIZE {
+                        let _ = embed_batch(
+                            embeddings_queue,
+                            &embedding_provider_clone,
+                            db_update_tx_clone.clone(),
+                        )
+                        .await;
+
+                        embeddings_queue = vec![];
+                        queue_len = 0;
+
+                        request_count += 1;
+                    }
+                }
+
+                if queue_len > 0 {
+                    let _ = embed_batch(
+                        embeddings_queue,
+                        &embedding_provider_clone,
+                        db_update_tx_clone.clone(),
+                    )
+                    .await;
+                    request_count += 1;
+                }
+            });
 
             let fs_clone = fs.clone();
-            let db_update_tx_clone = db_update_tx.clone();
-            let embedding_provider_clone = embedding_provider.clone();
 
             let _paths_update_task = cx.background().spawn(async move {
                 let mut parser = Parser::new();
                 let mut cursor = QueryCursor::new();
                 while let Ok((worktree_id, file_path, language, mtime)) = paths_rx.recv().await {
                     log::info!("Parsing File: {:?}", &file_path);
-                    if let Some(indexed_file) = Self::index_file(
+                    if let Some((indexed_file, document_spans)) = Self::index_file(
                         &mut cursor,
                         &mut parser,
-                        embedding_provider_clone.as_ref(),
                         &fs_clone,
                         language,
-                        file_path,
+                        file_path.clone(),
                         mtime,
                     )
                     .await
                     .log_err()
                     {
-                        db_update_tx_clone
-                            .try_send(DbWrite::InsertFile {
-                                worktree_id,
-                                indexed_file,
-                            })
+                        embeddings_tx
+                            .try_send((worktree_id, indexed_file, document_spans))
                            .unwrap();
                    }
                }
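The embed_batch helper above flattens every queued file's spans into one request and relies on the provider returning embeddings in the same order; the (i, j) walk then deals them back out per file, with the inner while skipping files that produced no documents. One subtlety: Vec::pop removes from the back of the vector, so the loop as written consumes the batch last-first even though the comment assumes in-order assignment. A standalone sketch of the in-order re-association, with simplified stand-in types (Doc here is a hypothetical reduction of Document):

// Deal a flat, in-order embedding list back out to per-file documents,
// mirroring the (i, j) walk in embed_batch. Draining from the front
// preserves the provider's ordering.
#[derive(Default)]
struct Doc {
    embedding: Vec<f32>,
}

fn assign_embeddings(files: &mut [Vec<Doc>], embeddings: Vec<Vec<f32>>) {
    let mut it = embeddings.into_iter();
    for docs in files.iter_mut() {
        // Files with no documents are skipped, like the inner `while` in the diff.
        for doc in docs.iter_mut() {
            doc.embedding = it.next().expect("one embedding per document");
        }
    }
    assert!(it.next().is_none(), "provider returned extra embeddings");
}

fn main() {
    let mut files = vec![vec![Doc::default(), Doc::default()], vec![], vec![Doc::default()]];
    let embeddings = vec![vec![1.0], vec![2.0], vec![3.0]];
    assign_embeddings(&mut files, embeddings);
    assert_eq!(files[2][0].embedding, vec![3.0]); // empty file in the middle was skipped
}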
@@ -251,6 +339,7 @@ impl VectorStore {
                 projects: HashMap::new(),
                 _db_update_task,
                 _paths_update_task,
+                _embeddings_update_task,
             }
         }))
     }
@@ -258,12 +347,11 @@ impl VectorStore {
     async fn index_file(
         cursor: &mut QueryCursor,
         parser: &mut Parser,
-        embedding_provider: &dyn EmbeddingProvider,
         fs: &Arc<dyn Fs>,
         language: Arc<Language>,
         file_path: PathBuf,
         mtime: SystemTime,
-    ) -> Result<IndexedFile> {
+    ) -> Result<(IndexedFile, Vec<String>)> {
         let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
         let embedding_config = grammar
             .embedding_config
@@ -298,7 +386,7 @@ impl VectorStore {
             if let Some((item, name)) =
                 content.get(item_range.clone()).zip(content.get(name_range))
             {
-                context_spans.push(item);
+                context_spans.push(item.to_string());
                 documents.push(Document {
                     name: name.to_string(),
                     offset: item_range.start,
@@ -308,18 +396,14 @@ impl VectorStore {
             }
         }
 
-        if !documents.is_empty() {
-            let embeddings = embedding_provider.embed_batch(context_spans).await?;
-            for (document, embedding) in documents.iter_mut().zip(embeddings) {
-                document.embedding = embedding;
-            }
-        }
-
-        return Ok(IndexedFile {
+        return Ok((
+            IndexedFile {
                 path: file_path,
                 mtime,
                 documents,
-        });
+            },
+            context_spans,
+        ));
     }
 
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
@@ -454,6 +538,9 @@ impl VectorStore {
         .detach();
 
         this.update(&mut cx, |this, cx| {
+            // The below is managing for updates on save.
+            // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
+            // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
             let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
                 if let Some(project_state) = this.projects.get(&project.downgrade()) {
                     let worktree_db_ids = project_state.worktree_db_ids.clone();
@@ -554,8 +641,6 @@ impl VectorStore {
                 );
             });
 
-            log::info!("Semantic Indexing Complete!");
-
             anyhow::Ok(())
         })
     }
@@ -591,8 +676,6 @@ impl VectorStore {
             })
             .collect::<Vec<_>>();
 
-        log::info!("Searching for: {:?}", phrase);
-
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
         cx.spawn(|this, cx| async move {