Moved semantic index to use an embeddings queue that batches requests and manages atomic database writes

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
KCaverly 2023-08-30 16:58:45 -04:00
parent 76ce52df4e
commit 5abad58b0d
3 changed files with 55 additions and 222 deletions

View file

@ -1,10 +1,8 @@
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime}; use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
use gpui::executor::Background;
use gpui::AppContext;
use parking_lot::Mutex; use parking_lot::Mutex;
use smol::channel; use smol::channel;
use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
#[derive(Clone)] #[derive(Clone)]
pub struct FileToEmbed { pub struct FileToEmbed {
@ -38,6 +36,7 @@ impl PartialEq for FileToEmbed {
pub struct EmbeddingQueue { pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>, embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileToEmbedFragment>, pending_batch: Vec<FileToEmbedFragment>,
executor: Arc<Background>,
pending_batch_token_count: usize, pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>, finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>, finished_files_rx: channel::Receiver<FileToEmbed>,
@ -49,10 +48,11 @@ pub struct FileToEmbedFragment {
} }
impl EmbeddingQueue { impl EmbeddingQueue {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self { pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded(); let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self { Self {
embedding_provider, embedding_provider,
executor,
pending_batch: Vec::new(), pending_batch: Vec::new(),
pending_batch_token_count: 0, pending_batch_token_count: 0,
finished_files_tx, finished_files_tx,
@ -60,7 +60,12 @@ impl EmbeddingQueue {
} }
} }
pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) { pub fn push(&mut self, file: FileToEmbed) {
if file.documents.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file)); let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileToEmbedFragment { self.pending_batch.push(FileToEmbedFragment {
@ -73,7 +78,7 @@ impl EmbeddingQueue {
let next_token_count = self.pending_batch_token_count + document.token_count; let next_token_count = self.pending_batch_token_count + document.token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() { if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end; let range_end = fragment_range.end;
self.flush(cx); self.flush();
self.pending_batch.push(FileToEmbedFragment { self.pending_batch.push(FileToEmbedFragment {
file: file.clone(), file: file.clone(),
document_range: range_end..range_end, document_range: range_end..range_end,
@ -86,7 +91,7 @@ impl EmbeddingQueue {
} }
} }
pub fn flush(&mut self, cx: &mut AppContext) { pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch); let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0; self.pending_batch_token_count = 0;
if batch.is_empty() { if batch.is_empty() {
@ -95,7 +100,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone(); let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
cx.background().spawn(async move { self.executor.spawn(async move {
let mut spans = Vec::new(); let mut spans = Vec::new();
for fragment in &batch { for fragment in &batch {
let file = fragment.file.lock(); let file = fragment.file.lock();

View file

@ -1,5 +1,6 @@
mod db; mod db;
mod embedding; mod embedding;
mod embedding_queue;
mod parsing; mod parsing;
pub mod semantic_index_settings; pub mod semantic_index_settings;
@ -10,6 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use db::VectorDatabase; use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{channel::oneshot, Future}; use futures::{channel::oneshot, Future};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Buffer, Language, LanguageRegistry}; use language::{Anchor, Buffer, Language, LanguageRegistry};
@ -23,7 +25,6 @@ use smol::channel;
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap},
mem,
ops::Range, ops::Range,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::{Arc, Weak}, sync::{Arc, Weak},
@ -38,7 +39,6 @@ use util::{
use workspace::WorkspaceCreated; use workspace::WorkspaceCreated;
const SEMANTIC_INDEX_VERSION: usize = 7; const SEMANTIC_INDEX_VERSION: usize = 7;
const EMBEDDINGS_BATCH_SIZE: usize = 80;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600); const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(600);
pub fn init( pub fn init(
@ -106,9 +106,8 @@ pub struct SemanticIndex {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbOperation>, db_update_tx: channel::Sender<DbOperation>,
parsing_files_tx: channel::Sender<PendingFile>, parsing_files_tx: channel::Sender<PendingFile>,
_embedding_task: Task<()>,
_db_update_task: Task<()>, _db_update_task: Task<()>,
_embed_batch_tasks: Vec<Task<()>>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>, _parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>, projects: HashMap<WeakModelHandle<Project>, ProjectState>,
} }
@ -128,7 +127,7 @@ struct ChangedPathInfo {
} }
#[derive(Clone)] #[derive(Clone)]
struct JobHandle { pub struct JobHandle {
/// The outer Arc is here to count the clones of a JobHandle instance; /// The outer Arc is here to count the clones of a JobHandle instance;
/// when the last handle to a given job is dropped, we decrement a counter (just once). /// when the last handle to a given job is dropped, we decrement a counter (just once).
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>, tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@ -230,17 +229,6 @@ enum DbOperation {
}, },
} }
enum EmbeddingJob {
Enqueue {
worktree_id: i64,
path: PathBuf,
mtime: SystemTime,
documents: Vec<Document>,
job_handle: JobHandle,
},
Flush,
}
impl SemanticIndex { impl SemanticIndex {
pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> { pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() { if cx.has_global::<ModelHandle<Self>>() {
@ -287,52 +275,35 @@ impl SemanticIndex {
} }
}); });
// Group documents into batches and send them to the embedding provider. let embedding_queue =
let (embed_batch_tx, embed_batch_rx) = EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>(); let _embedding_task = cx.background().spawn({
let mut _embed_batch_tasks = Vec::new(); let embedded_files = embedding_queue.finished_files();
for _ in 0..cx.background().num_cpus() {
let embed_batch_rx = embed_batch_rx.clone();
_embed_batch_tasks.push(cx.background().spawn({
let db_update_tx = db_update_tx.clone(); let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move { async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await { while let Ok(file) = embedded_files.recv().await {
Self::compute_embeddings_for_batch( db_update_tx
embeddings_queue, .try_send(DbOperation::InsertFile {
&embedding_provider, worktree_id: file.worktree_id,
&db_update_tx, documents: file.documents,
) path: file.path,
.await; mtime: file.mtime,
job_handle: file.job_handle,
})
.ok();
} }
} }
}));
}
// Group documents into batches and send them to the embedding provider.
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
let mut queue_len = 0;
let mut embeddings_queue = vec![];
while let Ok(job) = batch_files_rx.recv().await {
Self::enqueue_documents_to_embed(
job,
&mut queue_len,
&mut embeddings_queue,
&embed_batch_tx,
);
}
}); });
// Parse files into embeddable documents. // Parse files into embeddable documents.
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>(); let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
let embedding_queue = Arc::new(Mutex::new(embedding_queue));
let mut _parsing_files_tasks = Vec::new(); let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() { for _ in 0..cx.background().num_cpus() {
let fs = fs.clone(); let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone(); let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone(); let embedding_provider = embedding_provider.clone();
let embedding_queue = embedding_queue.clone();
_parsing_files_tasks.push(cx.background().spawn(async move { _parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
while let Ok(pending_file) = parsing_files_rx.recv().await { while let Ok(pending_file) = parsing_files_rx.recv().await {
@ -340,9 +311,8 @@ impl SemanticIndex {
&fs, &fs,
pending_file, pending_file,
&mut retriever, &mut retriever,
&batch_files_tx, &embedding_queue,
&parsing_files_rx, &parsing_files_rx,
&db_update_tx,
) )
.await; .await;
} }
@ -361,8 +331,7 @@ impl SemanticIndex {
db_update_tx, db_update_tx,
parsing_files_tx, parsing_files_tx,
_db_update_task, _db_update_task,
_embed_batch_tasks, _embedding_task,
_batch_files_task,
_parsing_files_tasks, _parsing_files_tasks,
projects: HashMap::new(), projects: HashMap::new(),
} }
@ -403,136 +372,12 @@ impl SemanticIndex {
} }
} }
async fn compute_embeddings_for_batch(
mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embedding_provider: &Arc<dyn EmbeddingProvider>,
db_update_tx: &channel::Sender<DbOperation>,
) {
let mut batch_documents = vec![];
for (_, documents, _, _, _) in embeddings_queue.iter() {
batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
}
if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
log::trace!(
"created {} embeddings for {} files",
embeddings.len(),
embeddings_queue.len(),
);
let mut i = 0;
let mut j = 0;
for embedding in embeddings.iter() {
while embeddings_queue[i].1.len() == j {
i += 1;
j = 0;
}
embeddings_queue[i].1[j].embedding = embedding.to_owned();
j += 1;
}
for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents,
path,
mtime,
job_handle,
})
.await
.unwrap();
}
} else {
// Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id,
documents: vec![],
path,
mtime,
job_handle,
})
.await
.unwrap();
}
}
}
fn enqueue_documents_to_embed(
job: EmbeddingJob,
queue_len: &mut usize,
embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
) {
// Handle edge case where individual file has more documents than max batch size
let should_flush = match job {
EmbeddingJob::Enqueue {
documents,
worktree_id,
path,
mtime,
job_handle,
} => {
// If documents is greater than embeddings batch size, recursively batch existing rows.
if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
let first_job = EmbeddingJob::Enqueue {
documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
worktree_id,
path: path.clone(),
mtime,
job_handle: job_handle.clone(),
};
Self::enqueue_documents_to_embed(
first_job,
queue_len,
embeddings_queue,
embed_batch_tx,
);
let second_job = EmbeddingJob::Enqueue {
documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
worktree_id,
path: path.clone(),
mtime,
job_handle: job_handle.clone(),
};
Self::enqueue_documents_to_embed(
second_job,
queue_len,
embeddings_queue,
embed_batch_tx,
);
return;
} else {
*queue_len += &documents.len();
embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
*queue_len >= EMBEDDINGS_BATCH_SIZE
}
}
EmbeddingJob::Flush => true,
};
if should_flush {
embed_batch_tx
.try_send(mem::take(embeddings_queue))
.unwrap();
*queue_len = 0;
}
}
async fn parse_file( async fn parse_file(
fs: &Arc<dyn Fs>, fs: &Arc<dyn Fs>,
pending_file: PendingFile, pending_file: PendingFile,
retriever: &mut CodeContextRetriever, retriever: &mut CodeContextRetriever,
batch_files_tx: &channel::Sender<EmbeddingJob>, embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
parsing_files_rx: &channel::Receiver<PendingFile>, parsing_files_rx: &channel::Receiver<PendingFile>,
db_update_tx: &channel::Sender<DbOperation>,
) { ) {
let Some(language) = pending_file.language else { let Some(language) = pending_file.language else {
return; return;
@ -549,33 +394,18 @@ impl SemanticIndex {
documents.len() documents.len()
); );
if documents.len() == 0 { embedding_queue.lock().push(FileToEmbed {
db_update_tx
.send(DbOperation::InsertFile {
worktree_id: pending_file.worktree_db_id,
documents,
path: pending_file.relative_path,
mtime: pending_file.modified_time,
job_handle: pending_file.job_handle,
})
.await
.unwrap();
} else {
batch_files_tx
.try_send(EmbeddingJob::Enqueue {
worktree_id: pending_file.worktree_db_id, worktree_id: pending_file.worktree_db_id,
path: pending_file.relative_path, path: pending_file.relative_path,
mtime: pending_file.modified_time, mtime: pending_file.modified_time,
job_handle: pending_file.job_handle, job_handle: pending_file.job_handle,
documents, documents,
}) });
.unwrap();
}
} }
} }
if parsing_files_rx.len() == 0 { if parsing_files_rx.len() == 0 {
batch_files_tx.try_send(EmbeddingJob::Flush).unwrap(); embedding_queue.lock().flush();
} }
} }
@ -881,7 +711,7 @@ impl SemanticIndex {
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?; let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
let phrase_embedding = embedding_provider let phrase_embedding = embedding_provider
.embed_batch(vec![&phrase]) .embed_batch(vec![phrase])
.await? .await?
.into_iter() .into_iter()
.next() .next()

View file

@ -235,17 +235,15 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone());
let finished_files = cx.update(|cx| { let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files { for file in &files {
queue.push(file.clone(), cx); queue.push(file.clone());
} }
queue.flush(cx); queue.flush();
queue.finished_files()
});
cx.foreground().run_until_parked(); cx.foreground().run_until_parked();
let finished_files = queue.finished_files();
let mut embedded_files: Vec<_> = files let mut embedded_files: Vec<_> = files
.iter() .iter()
.map(|_| finished_files.try_recv().expect("no finished file")) .map(|_| finished_files.try_recv().expect("no finished file"))