Abstract away how database transactions are executed

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
Antonio Scandurra 2023-08-31 16:54:11 +02:00
parent 7d4d6c871b
commit 35440be98e
2 changed files with 397 additions and 432 deletions
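In broad strokes, the diff below replaces SemanticIndex's `db_update_tx` channel of `DbOperation` variants with plain async methods on `VectorDatabase` itself, which now receives the background executor at construction and runs transactions internally. A minimal standalone sketch of that pattern follows; it is not the actual zed code, and names such as VectorDatabaseSketch, FakeDb, and transact are hypothetical:

// Sketch only (hypothetical names): a database handle that hides *where*
// transactions run. Callers await ordinary async methods; internally each
// call is shipped to a dedicated worker thread as a boxed closure.
use futures::channel::oneshot;
use std::sync::mpsc;
use std::thread;

type Transaction = Box<dyn FnOnce(&mut FakeDb) + Send>;

// Stand-in for the real connection (e.g. a SQLite handle).
struct FakeDb {
    files: Vec<String>,
}

#[derive(Clone)]
struct VectorDatabaseSketch {
    tx: mpsc::Sender<Transaction>,
}

impl VectorDatabaseSketch {
    fn new() -> Self {
        let (tx, rx) = mpsc::channel::<Transaction>();
        // A single background thread owns the connection and executes every
        // transaction it receives, in order.
        thread::spawn(move || {
            let mut db = FakeDb { files: Vec::new() };
            while let Ok(transact) = rx.recv() {
                transact(&mut db);
            }
        });
        Self { tx }
    }

    // Public methods enqueue a closure and await its result over a oneshot,
    // so callers never see the channel or the worker thread.
    async fn insert_file(&self, path: String) -> anyhow::Result<()> {
        self.transact(move |db| {
            db.files.push(path);
            Ok(())
        })
        .await
    }

    async fn file_count(&self) -> anyhow::Result<usize> {
        self.transact(|db| Ok(db.files.len())).await
    }

    async fn transact<T, F>(&self, f: F) -> anyhow::Result<T>
    where
        T: Send + 'static,
        F: FnOnce(&mut FakeDb) -> anyhow::Result<T> + Send + 'static,
    {
        let (result_tx, result_rx) = oneshot::channel();
        self.tx
            .send(Box::new(move |db: &mut FakeDb| {
                result_tx.send(f(db)).ok();
            }))
            .map_err(|_| anyhow::anyhow!("database thread stopped"))?;
        result_rx
            .await
            .map_err(|_| anyhow::anyhow!("database thread dropped the result"))?
    }
}

fn main() -> anyhow::Result<()> {
    futures::executor::block_on(async {
        let db = VectorDatabaseSketch::new();
        db.insert_file("src/lib.rs".into()).await?;
        assert_eq!(db.file_count().await?, 1);
        Ok(())
    })
}

The point of the indirection is that callers such as SemanticIndex only ever await ordinary methods; on which thread a transaction executes becomes an implementation detail of the database handle rather than something the index has to orchestrate through a DbOperation channel.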


@@ -12,11 +12,10 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{
     search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
@@ -101,13 +100,11 @@ pub fn init(
 pub struct SemanticIndex {
     fs: Arc<dyn Fs>,
-    database_url: Arc<PathBuf>,
+    db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _embedding_task: Task<()>,
-    _db_update_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -203,32 +200,6 @@ pub struct SearchResult {
     pub range: Range<Anchor>,
 }
 
-enum DbOperation {
-    InsertFile {
-        worktree_id: i64,
-        documents: Vec<Document>,
-        path: PathBuf,
-        mtime: SystemTime,
-        job_handle: JobHandle,
-    },
-    Delete {
-        worktree_id: i64,
-        path: PathBuf,
-    },
-    FindOrCreateWorktree {
-        path: PathBuf,
-        sender: oneshot::Sender<Result<i64>>,
-    },
-    FileMTimes {
-        worktree_id: i64,
-        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
-    },
-    WorktreePreviouslyIndexed {
-        path: Arc<Path>,
-        sender: oneshot::Sender<Result<bool>>,
-    },
-}
-
 impl SemanticIndex {
     pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
@@ -245,18 +216,14 @@ impl SemanticIndex {
     async fn new(
         fs: Arc<dyn Fs>,
-        database_url: PathBuf,
+        database_path: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
         let t0 = Instant::now();
-        let database_url = Arc::new(database_url);
-        let db = cx
-            .background()
-            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
-            .await?;
+        let database_path = Arc::from(database_path);
+        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 
         log::trace!(
             "db initialization took {:?} milliseconds",
@@ -265,32 +232,16 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
-            // Perform database operations
-            let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn({
-                async move {
-                    while let Ok(job) = db_update_rx.recv().await {
-                        Self::run_db_operation(&db, job)
-                    }
-                }
-            });
             let embedding_queue =
                 EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
-                let db_update_tx = db_update_tx.clone();
+                let db = db.clone();
                 async move {
                     while let Ok(file) = embedded_files.recv().await {
-                        db_update_tx
-                            .try_send(DbOperation::InsertFile {
-                                worktree_id: file.worktree_id,
-                                documents: file.documents,
-                                path: file.path,
-                                mtime: file.mtime,
-                                job_handle: file.job_handle,
-                            })
-                            .ok();
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                            .await
+                            .log_err();
                     }
                 }
             });
@@ -325,12 +276,10 @@ impl SemanticIndex {
             );
             Self {
                 fs,
-                database_url,
+                db,
                 embedding_provider,
                 language_registry,
-                db_update_tx,
                 parsing_files_tx,
-                _db_update_task,
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
@@ -338,40 +287,6 @@ impl SemanticIndex {
         }))
     }
 
-    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
-        match job {
-            DbOperation::InsertFile {
-                worktree_id,
-                documents,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                db.insert_file(worktree_id, path, mtime, documents)
-                    .log_err();
-                drop(job_handle)
-            }
-            DbOperation::Delete { worktree_id, path } => {
-                db.delete_file(worktree_id, path).log_err();
-            }
-            DbOperation::FindOrCreateWorktree { path, sender } => {
-                let id = db.find_or_create_worktree(&path);
-                sender.send(id).ok();
-            }
-            DbOperation::FileMTimes {
-                worktree_id: worktree_db_id,
-                sender,
-            } => {
-                let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                sender.send(file_mtimes).ok();
-            }
-            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
-                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
-                sender.send(worktree_indexed).ok();
-            }
-        }
-    }
-
     async fn parse_file(
         fs: &Arc<dyn Fs>,
         pending_file: PendingFile,
@@ -409,36 +324,6 @@ impl SemanticIndex {
         }
     }
 
-    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn get_file_mtimes(
-        &self,
-        worktree_id: i64,
-    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FileMTimes {
-                worktree_id,
-                sender: tx,
-            })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
     pub fn project_previously_indexed(
         &mut self,
         project: ModelHandle<Project>,
@@ -447,7 +332,10 @@ impl SemanticIndex {
         let worktrees_indexed_previously = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
+            .map(|worktree| {
+                self.db
+                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
+            })
             .collect::<Vec<_>>();
         cx.spawn(|_, _cx| async move {
             let worktree_indexed_previously =
@@ -528,7 +416,8 @@ impl SemanticIndex {
             .read(cx)
             .worktrees(cx)
             .map(|worktree| {
-                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+                self.db
+                    .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
             })
             .collect::<Vec<_>>();
@@ -559,7 +448,7 @@ impl SemanticIndex {
                 db_ids_by_worktree_id.insert(worktree.id(), db_id);
                 worktree_file_mtimes.insert(
                     worktree.id(),
-                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                    this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
                         .await?,
                 );
             }
@@ -704,11 +593,12 @@ impl SemanticIndex {
             .collect::<Vec<_>>();
         let embedding_provider = self.embedding_provider.clone();
-        let database_url = self.database_url.clone();
+        let db_path = self.db.path().clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
             let t0 = Instant::now();
-            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
+            let database =
+                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 
             let phrase_embedding = embedding_provider
                 .embed_batch(vec![phrase])
@@ -722,8 +612,9 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );
-            let file_ids =
-                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
+            let file_ids = database
+                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
+                .await?;
             let batch_n = cx.background().num_cpus();
             let ids_len = file_ids.clone().len();
@@ -733,27 +624,24 @@ impl SemanticIndex {
                 ids_len / batch_n
             };
-            let mut result_tasks = Vec::new();
+            let mut batch_results = Vec::new();
             for batch in file_ids.chunks(batch_size) {
                 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
                 let limit = limit.clone();
                 let fs = fs.clone();
-                let database_url = database_url.clone();
+                let db_path = db_path.clone();
                 let phrase_embedding = phrase_embedding.clone();
-                let task = cx.background().spawn(async move {
-                    let database = VectorDatabase::new(fs, database_url).await.log_err();
-                    if database.is_none() {
-                        return Err(anyhow!("failed to acquire database connection"));
-                    } else {
-                        database
-                            .unwrap()
-                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
-                    }
-                });
-                result_tasks.push(task);
+                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
+                    .await
+                    .log_err()
+                {
+                    batch_results.push(async move {
+                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
+                            .await
+                    });
+                }
             }
-            let batch_results = futures::future::join_all(result_tasks).await;
+            let batch_results = futures::future::join_all(batch_results).await;
             let mut results = Vec::new();
             for batch_result in batch_results {
@@ -772,7 +660,7 @@ impl SemanticIndex {
             }
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-            let documents = database.get_documents_by_ids(ids.as_slice())?;
+            let documents = database.get_documents_by_ids(ids.as_slice()).await?;
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();
@@ -822,7 +710,8 @@ impl SemanticIndex {
         cx: &mut AsyncAppContext,
     ) {
         let mut pending_files = Vec::new();
-        let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
+        let mut files_to_delete = Vec::new();
+        let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
             if let Some(project_state) = this.projects.get_mut(&project.downgrade()) {
                 let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
                 let db_ids = &project_state.worktree_db_ids;
@@ -853,12 +742,7 @@ impl SemanticIndex {
                     };
                     if info.is_deleted {
-                        this.db_update_tx
-                            .try_send(DbOperation::Delete {
-                                worktree_id: worktree_db_id,
-                                path: path.path.to_path_buf(),
-                            })
-                            .ok();
+                        files_to_delete.push((worktree_db_id, path.path.to_path_buf()));
                     } else {
                         let absolute_path = worktree.read(cx).absolutize(&path.path);
                         let job_handle = JobHandle::new(&outstanding_job_count_tx);
@@ -877,11 +761,16 @@ impl SemanticIndex {
             }
             (
+                this.db.clone(),
                 this.language_registry.clone(),
                 this.parsing_files_tx.clone(),
             )
         });
 
+        for (worktree_db_id, path) in files_to_delete {
+            db.delete_file(worktree_db_id, path).await.log_err();
+        }
+
         for mut pending_file in pending_files {
             if let Ok(language) = language_registry
                 .language_for_file(&pending_file.relative_path, None)