Abstract away how database transactions are executed

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
Author: Antonio Scandurra
Date: 2023-08-31 16:54:11 +02:00
parent 7d4d6c871b
commit 35440be98e
2 changed files with 397 additions and 432 deletions
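The shape of the change: a `rusqlite::Connection` cannot be shared across threads, so instead of handing the connection to callers, `VectorDatabase` now spawns a single background task that owns the connection and executes boxed closures sent over a channel, returning each result through a oneshot. A minimal sketch of that pattern, using only the standard library (threads and `mpsc` standing in for gpui's background executor, smol's channel, and `futures::channel::oneshot`; `Connection` is a stand-in for `rusqlite::Connection`):

    use std::sync::mpsc;
    use std::thread;

    // Stand-in for rusqlite::Connection: a handle that must stay on one thread.
    struct Connection;

    // A database operation: a closure that runs against the owned connection.
    type Transaction = Box<dyn FnOnce(&mut Connection) + Send + 'static>;

    #[derive(Clone)]
    struct Database {
        transactions: mpsc::Sender<Transaction>,
    }

    impl Database {
        fn new() -> Self {
            let (tx, rx) = mpsc::channel::<Transaction>();
            // The spawned thread is the sole owner of the connection; operations
            // are serialized by the order they arrive on the channel.
            thread::spawn(move || {
                let mut connection = Connection;
                while let Ok(transaction) = rx.recv() {
                    transaction(&mut connection);
                }
            });
            Self { transactions: tx }
        }

        // Ship a closure to the connection thread and wait for its result.
        // The commit's async version returns `impl Future<Output = Result<T>>`
        // and awaits a `futures::channel::oneshot` instead of blocking.
        fn transact<F, T>(&self, f: F) -> T
        where
            F: FnOnce(&mut Connection) -> T + Send + 'static,
            T: Send + 'static,
        {
            let (result_tx, result_rx) = mpsc::channel();
            self.transactions
                .send(Box::new(move |connection| {
                    let _ = result_tx.send(f(connection));
                }))
                .expect("connection thread exited");
            result_rx.recv().expect("transaction was dropped")
        }
    }

Because the channel is the only route to the connection, operations are serialized in arrival order with no mutex, and the handle can be `Clone` and `&self` throughout.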


@@ -1,5 +1,7 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
+use futures::channel::oneshot;
+use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -9,12 +11,14 @@ use rusqlite::{
 use std::{
     cmp::Ordering,
     collections::HashMap,
+    future::Future,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
     sync::Arc,
     time::SystemTime,
 };
+use util::TryFutureExt;
 
 #[derive(Debug)]
 pub struct FileRecord {
@@ -51,72 +55,109 @@ impl FromSql for Sha1 {
     }
 }
 
+#[derive(Clone)]
 pub struct VectorDatabase {
-    db: rusqlite::Connection,
+    path: Arc<Path>,
+    transactions: smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>,
 }
 
 impl VectorDatabase {
-    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
+    pub async fn new(
+        fs: Arc<dyn Fs>,
+        path: Arc<Path>,
+        executor: Arc<executor::Background>,
+    ) -> Result<Self> {
         if let Some(db_directory) = path.parent() {
             fs.create_dir(db_directory).await?;
         }
 
+        let (transactions_tx, transactions_rx) =
+            smol::channel::unbounded::<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>();
+        executor
+            .spawn({
+                let path = path.clone();
+                async move {
+                    let connection = rusqlite::Connection::open(&path)?;
+                    while let Ok(transaction) = transactions_rx.recv().await {
+                        transaction(&connection);
+                    }
+                    anyhow::Ok(())
+                }
+                .log_err()
+            })
+            .detach();
+
         let this = Self {
-            db: rusqlite::Connection::open(path.as_path())?,
+            transactions: transactions_tx,
+            path,
         };
-        this.initialize_database()?;
+        this.initialize_database().await?;
         Ok(this)
     }
 
-    fn get_existing_version(&self) -> Result<i64> {
-        let mut version_query = self
-            .db
-            .prepare("SELECT version from semantic_index_config")?;
-        version_query
-            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
-            .map_err(|err| anyhow!("version query failed: {err}"))
+    pub fn path(&self) -> &Arc<Path> {
+        &self.path
     }
 
-    fn initialize_database(&self) -> Result<()> {
-        rusqlite::vtab::array::load_module(&self.db)?;
+    fn transact<F, T>(&self, transaction: F) -> impl Future<Output = Result<T>>
+    where
+        F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result<T>,
+        T: 'static + Send,
+    {
+        let (tx, rx) = oneshot::channel();
+        let transactions = self.transactions.clone();
+        async move {
+            if transactions
+                .send(Box::new(|connection| {
+                    let result = transaction(connection);
+                    let _ = tx.send(result);
+                }))
+                .await
+                .is_err()
+            {
+                return Err(anyhow!("connection was dropped"))?;
+            }
+            rx.await?
+        }
+    }
+
+    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+        self.transact(|db| {
+            rusqlite::vtab::array::load_module(&db)?;
 
         // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
-        if self
-            .get_existing_version()
-            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
-        {
+            let version_query = db.prepare("SELECT version from semantic_index_config");
+            let version = version_query
+                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
             log::trace!("vector database schema up to date");
             return Ok(());
         }
 
         log::trace!("vector database schema out of date. updating...");
-        self.db
-            .execute("DROP TABLE IF EXISTS documents", [])
+            db.execute("DROP TABLE IF EXISTS documents", [])
             .context("failed to drop 'documents' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS files", [])
+            db.execute("DROP TABLE IF EXISTS files", [])
            .context("failed to drop 'files' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS worktrees", [])
+            db.execute("DROP TABLE IF EXISTS worktrees", [])
            .context("failed to drop 'worktrees' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
+            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
            .context("failed to drop 'semantic_index_config' table")?;
 
        // Initialize Vector Databasing Tables
-        self.db.execute(
+            db.execute(
            "CREATE TABLE semantic_index_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;
 
-        self.db.execute(
+            db.execute(
            "INSERT INTO semantic_index_config (version) VALUES (?1)",
            params![SEMANTIC_INDEX_VERSION],
        )?;
 
-        self.db.execute(
+            db.execute(
            "CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
@@ -126,7 +167,7 @@ impl VectorDatabase {
            [],
        )?;
 
-        self.db.execute(
+            db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
@@ -138,7 +179,7 @@ impl VectorDatabase {
            [],
        )?;
 
-        self.db.execute(
+            db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
@@ -154,14 +195,21 @@ impl VectorDatabase {
 
        log::trace!("vector database initialized with updated schema.");
        Ok(())
+        })
     }
 
-    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
-        self.db.execute(
+    pub fn delete_file(
+        &self,
+        worktree_id: i64,
+        delete_path: PathBuf,
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
+        })
     }
 
     pub fn insert_file(
@@ -170,27 +218,29 @@ impl VectorDatabase {
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
-    ) -> Result<()> {
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
        // Return the existing ID, if both the file and mtime match
        let mtime = Timestamp::from(mtime);
-        let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
+            let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
        let existing_id = existing_id_query
            .query_row(
                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
                |row| Ok(row.get::<_, i64>(0)?),
-            )
-            .map_err(|err| anyhow!(err));
+                );
        let file_id = if existing_id.is_ok() {
            // If already exists, just return the existing id
-            existing_id.unwrap()
+                existing_id?
        } else {
            // Delete Existing Row
-            self.db.execute(
+                db.execute(
                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
                params![worktree_id, path.to_str()],
            )?;
-            self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
-            self.db.last_insert_rowid()
+                db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
+                db.last_insert_rowid()
        };
 
        // Currently inserting at approximately 3400 documents a second
@@ -199,7 +249,7 @@ impl VectorDatabase {
            let embedding_blob = bincode::serialize(&document.embedding)?;
            let sha_blob = bincode::serialize(&document.sha1)?;
-            self.db.execute(
+                db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                params![
                    file_id,
@@ -213,53 +263,59 @@ impl VectorDatabase {
        }
 
        Ok(())
+        })
     }
 
-    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+    pub fn worktree_previously_indexed(
+        &self,
+        worktree_root_path: &Path,
+    ) -> impl Future<Output = Result<bool>> {
+        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
        let worktree_id = worktree_query
-            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
+                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
 
        if worktree_id.is_ok() {
            return Ok(true);
        } else {
            return Ok(false);
        }
+        })
     }
 
-    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
-        // Check that the absolute path doesnt exist
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+    pub fn find_or_create_worktree(
+        &self,
+        worktree_root_path: PathBuf,
+    ) -> impl Future<Output = Result<i64>> {
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
+                });
 
        if worktree_id.is_ok() {
-            return worktree_id;
+                return Ok(worktree_id?);
        }
 
        // If worktree_id is Err, insert new worktree
-        self.db.execute(
-            "
-            INSERT into worktrees (absolute_path) VALUES (?1)
-            ",
+            db.execute(
+                "INSERT into worktrees (absolute_path) VALUES (?1)",
            params![worktree_root_path.to_string_lossy()],
        )?;
-        Ok(self.db.last_insert_rowid())
+            Ok(db.last_insert_rowid())
+        })
     }
 
-    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
-        let mut statement = self.db.prepare(
+    pub fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        self.transact(move |db| {
+            let mut statement = db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
@@ -281,6 +337,7 @@ impl VectorDatabase {
            result.insert(row.0, row.1);
        }
        Ok(result)
+        })
     }
 
     pub fn top_k_search(
@@ -288,13 +345,16 @@ impl VectorDatabase {
        query_embedding: &Vec<f32>,
        limit: usize,
        file_ids: &[i64],
-    ) -> Result<Vec<(i64, f32)>> {
+    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+        let query_embedding = query_embedding.clone();
+        let file_ids = file_ids.to_vec();
+        self.transact(move |db| {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(file_ids, |id, embedding| {
+            Self::for_each_document(db, &file_ids, |id, embedding| {
            let similarity = dot(&embedding, &query_embedding);
-            let ix = match results
-                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
-            {
+                let ix = match results.binary_search_by(|(_, s)| {
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                }) {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
@@ -302,7 +362,8 @@ impl VectorDatabase {
            results.truncate(limit);
        })?;
-        Ok(results)
+            anyhow::Ok(results)
+        })
     }
 
     pub fn retrieve_included_file_ids(
@@ -310,8 +371,12 @@ impl VectorDatabase {
        worktree_ids: &[i64],
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
-    ) -> Result<Vec<i64>> {
-        let mut file_query = self.db.prepare(
+    ) -> impl Future<Output = Result<Vec<i64>>> {
+        let worktree_ids = worktree_ids.to_vec();
+        let includes = includes.to_vec();
+        let excludes = excludes.to_vec();
+        self.transact(move |db| {
+            let mut file_query = db.prepare(
            "
            SELECT
                id, relative_path
@@ -323,7 +388,7 @@ impl VectorDatabase {
        )?;
 
        let mut file_ids = Vec::<i64>::new();
-        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
+            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
 
        while let Some(row) = rows.next()? {
            let file_id = row.get(0)?;
@@ -336,11 +401,16 @@ impl VectorDatabase {
            }
        }
-        Ok(file_ids)
+            anyhow::Ok(file_ids)
+        })
     }
 
-    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
-        let mut query_statement = self.db.prepare(
+    fn for_each_document(
+        db: &rusqlite::Connection,
+        file_ids: &[i64],
+        mut f: impl FnMut(i64, Vec<f32>),
+    ) -> Result<()> {
+        let mut query_statement = db.prepare(
            "
            SELECT
                id, embedding
@@ -360,8 +430,13 @@ impl VectorDatabase {
        Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
-        let mut statement = self.db.prepare(
+    pub fn get_documents_by_ids(
+        &self,
+        ids: &[i64],
+    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
+        let ids = ids.to_vec();
+        self.transact(move |db| {
+            let mut statement = db.prepare(
            "
            SELECT
                documents.id,
@@ -377,7 +452,7 @@ impl VectorDatabase {
            ",
        )?;
 
-        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
+            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
@@ -393,7 +468,7 @@ impl VectorDatabase {
        }
 
        let mut results = Vec::with_capacity(ids.len());
-        for id in ids {
+            for id in &ids {
            let value = values_by_id
                .remove(id)
                .ok_or(anyhow!("missing document id {}", id))?;
@@ -401,6 +476,7 @@ impl VectorDatabase {
        }
        Ok(results)
+        })
     }
 }
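With this change, every public method on `VectorDatabase` returns `impl Future<Output = Result<_>>` rather than a bare `Result`. A hypothetical caller-side sketch (the path and file names are invented; assumes an async context and a `db` built by `VectorDatabase::new(fs, path, executor).await?`):

    let worktree_id = db.find_or_create_worktree(PathBuf::from("/projects/zed")).await?;
    let mtimes = db.get_file_mtimes(worktree_id).await?;
    db.delete_file(worktree_id, PathBuf::from("src/stale.rs")).await?;

The second file's diff below updates `SemanticIndex` to this calling convention, deleting the `DbOperation` message enum and the dispatcher task that interpreted it.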


@@ -12,11 +12,10 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{
    search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
@@ -101,13 +100,11 @@ pub fn init(
 pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
-    database_url: Arc<PathBuf>,
+    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _embedding_task: Task<()>,
-    _db_update_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -203,32 +200,6 @@ pub struct SearchResult {
    pub range: Range<Anchor>,
 }
 
-enum DbOperation {
-    InsertFile {
-        worktree_id: i64,
-        documents: Vec<Document>,
-        path: PathBuf,
-        mtime: SystemTime,
-        job_handle: JobHandle,
-    },
-    Delete {
-        worktree_id: i64,
-        path: PathBuf,
-    },
-    FindOrCreateWorktree {
-        path: PathBuf,
-        sender: oneshot::Sender<Result<i64>>,
-    },
-    FileMTimes {
-        worktree_id: i64,
-        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
-    },
-    WorktreePreviouslyIndexed {
-        path: Arc<Path>,
-        sender: oneshot::Sender<Result<bool>>,
-    },
-}
-
 impl SemanticIndex {
    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
        if cx.has_global::<ModelHandle<Self>>() {
@@ -245,18 +216,14 @@ impl SemanticIndex {
    async fn new(
        fs: Arc<dyn Fs>,
-        database_url: PathBuf,
+        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
-        let database_url = Arc::new(database_url);
-
-        let db = cx
-            .background()
-            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
-            .await?;
+        let database_path = Arc::from(database_path);
+        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 
        log::trace!(
            "db initialization took {:?} milliseconds",
@@ -265,32 +232,16 @@ impl SemanticIndex {
        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
 
-            // Perform database operations
-            let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn({
-                async move {
-                    while let Ok(job) = db_update_rx.recv().await {
-                        Self::run_db_operation(&db, job)
-                    }
-                }
-            });
-
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
-                let db_update_tx = db_update_tx.clone();
+                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
-                        db_update_tx
-                            .try_send(DbOperation::InsertFile {
-                                worktree_id: file.worktree_id,
-                                documents: file.documents,
-                                path: file.path,
-                                mtime: file.mtime,
-                                job_handle: file.job_handle,
-                            })
-                            .ok();
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                            .await
+                            .log_err();
                    }
                }
            });
@@ -325,12 +276,10 @@ impl SemanticIndex {
            );
            Self {
                fs,
-                database_url,
+                db,
                embedding_provider,
                language_registry,
-                db_update_tx,
                parsing_files_tx,
-                _db_update_task,
                _embedding_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
@@ -338,40 +287,6 @@ impl SemanticIndex {
        }))
     }
 
-    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
-        match job {
-            DbOperation::InsertFile {
-                worktree_id,
-                documents,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                db.insert_file(worktree_id, path, mtime, documents)
-                    .log_err();
-                drop(job_handle)
-            }
-            DbOperation::Delete { worktree_id, path } => {
-                db.delete_file(worktree_id, path).log_err();
-            }
-            DbOperation::FindOrCreateWorktree { path, sender } => {
-                let id = db.find_or_create_worktree(&path);
-                sender.send(id).ok();
-            }
-            DbOperation::FileMTimes {
-                worktree_id: worktree_db_id,
-                sender,
-            } => {
-                let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                sender.send(file_mtimes).ok();
-            }
-            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
-                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
-                sender.send(worktree_indexed).ok();
-            }
-        }
-    }
-
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
@@ -409,36 +324,6 @@ impl SemanticIndex {
        }
     }
 
-    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn get_file_mtimes(
-        &self,
-        worktree_id: i64,
-    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FileMTimes {
-                worktree_id,
-                sender: tx,
-            })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
    pub fn project_previously_indexed(
        &mut self,
        project: ModelHandle<Project>,
@@ -447,7 +332,10 @@ impl SemanticIndex {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
-            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
+            .map(|worktree| {
+                self.db
+                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
+            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
@@ -528,7 +416,8 @@ impl SemanticIndex {
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
-                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+                self.db
+                    .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();
@@ -559,7 +448,7 @@ impl SemanticIndex {
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
-                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                    this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
                        .await?,
                );
            }
@@ -704,11 +593,12 @@ impl SemanticIndex {
            .collect::<Vec<_>>();
 
        let embedding_provider = self.embedding_provider.clone();
-        let database_url = self.database_url.clone();
+        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
-            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
+            let database =
+                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 
            let phrase_embedding = embedding_provider
                .embed_batch(vec![phrase])
@@ -722,8 +612,9 @@ impl SemanticIndex {
                t0.elapsed().as_millis()
            );
 
-            let file_ids =
-                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
+            let file_ids = database
+                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
+                .await?;
 
            let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.clone().len();
@@ -733,27 +624,24 @@ impl SemanticIndex {
                ids_len / batch_n
            };
 
-            let mut result_tasks = Vec::new();
+            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
                let limit = limit.clone();
                let fs = fs.clone();
-                let database_url = database_url.clone();
+                let db_path = db_path.clone();
                let phrase_embedding = phrase_embedding.clone();
-                let task = cx.background().spawn(async move {
-                    let database = VectorDatabase::new(fs, database_url).await.log_err();
-                    if database.is_none() {
-                        return Err(anyhow!("failed to acquire database connection"));
-                    } else {
-                        database
-                            .unwrap()
-                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
-                    }
-                });
-                result_tasks.push(task);
+                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
+                    .await
+                    .log_err()
+                {
+                    batch_results.push(async move {
+                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
+                            .await
+                    });
+                }
            }
 
-            let batch_results = futures::future::join_all(result_tasks).await;
+            let batch_results = futures::future::join_all(batch_results).await;
 
            let mut results = Vec::new();
            for batch_result in batch_results {
@@ -772,7 +660,7 @@ impl SemanticIndex {
            }
 
            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-            let documents = database.get_documents_by_ids(ids.as_slice())?;
+            let documents = database.get_documents_by_ids(ids.as_slice()).await?;
 
            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
@@ -822,7 +710,8 @@ impl SemanticIndex {
        cx: &mut AsyncAppContext,
    ) {
        let mut pending_files = Vec::new();
-        let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
+        let mut files_to_delete = Vec::new();
+        let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
            if let Some(project_state) = this.projects.get_mut(&project.downgrade()) {
                let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
                let db_ids = &project_state.worktree_db_ids;
@@ -853,12 +742,7 @@ impl SemanticIndex {
                    };
 
                    if info.is_deleted {
-                        this.db_update_tx
-                            .try_send(DbOperation::Delete {
-                                worktree_id: worktree_db_id,
-                                path: path.path.to_path_buf(),
-                            })
-                            .ok();
+                        files_to_delete.push((worktree_db_id, path.path.to_path_buf()));
                    } else {
                        let absolute_path = worktree.read(cx).absolutize(&path.path);
                        let job_handle = JobHandle::new(&outstanding_job_count_tx);
@@ -877,11 +761,16 @@ impl SemanticIndex {
            }
 
            (
+                this.db.clone(),
                this.language_registry.clone(),
                this.parsing_files_tx.clone(),
            )
        });
 
+        for (worktree_db_id, path) in files_to_delete {
+            db.delete_file(worktree_db_id, path).await.log_err();
+        }
+
        for mut pending_file in pending_files {
            if let Ok(language) = language_registry
                .language_for_file(&pending_file.relative_path, None)
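The net effect at call sites, condensed from the hunks above: what was a fire-and-forget message to the dispatcher becomes a direct awaited call, while sequencing still happens on the connection task inside `VectorDatabase::transact`.

    // Before: describe the operation as a DbOperation message.
    this.db_update_tx
        .try_send(DbOperation::Delete { worktree_id, path })
        .ok();

    // After: call the database directly and await the result.
    db.delete_file(worktree_id, path).await.log_err();

Errors now surface at the call site, via `?` or `log_err()`, instead of disappearing inside the dispatcher.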