diff --git a/Cargo.lock b/Cargo.lock
index ff4caaa5a6..1ea1d1a1b4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7967,6 +7967,7 @@ dependencies = [
  "serde_json",
  "sha-1 0.10.1",
  "smol",
+ "tempdir",
  "tree-sitter",
  "tree-sitter-rust",
  "unindent",
diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml
index dbe0a2e69c..edc06bb295 100644
--- a/crates/vector_store/Cargo.toml
+++ b/crates/vector_store/Cargo.toml
@@ -17,7 +17,7 @@ util = { path = "../util" }
 anyhow.workspace = true
 futures.workspace = true
 smol.workspace = true
-rusqlite = { version = "0.27.0", features=["blob"] }
+rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
 isahc.workspace = true
 log.workspace = true
 tree-sitter.workspace = true
@@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] }
 tree-sitter-rust = "*"
 rand.workspace = true
 unindent.workspace = true
+tempdir.workspace = true
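Note on the feature change: `array` is what provides SQLite's `rarray()` table-valued function used by `get_documents_by_ids` below, and `modern_sqlite` tells rusqlite to assume a recent SQLite. A minimal standalone sketch of the `rarray()` pattern (not part of this diff, assuming rusqlite 0.27's array vtab API):

```rust
use std::rc::Rc;

use rusqlite::{params, types::Value, Connection};

fn main() -> rusqlite::Result<()> {
    let db = Connection::open_in_memory()?;
    // The array module must be loaded once per connection before rarray() is visible.
    rusqlite::vtab::array::load_module(&db)?;

    // rarray() takes a single Rc<Vec<Value>> parameter.
    let ids: Vec<Value> = [1i64, 2, 3].into_iter().map(Value::from).collect();
    let mut stmt = db.prepare("SELECT value FROM rarray(?1)")?;
    let values = stmt
        .query_map(params![Rc::new(ids)], |row| row.get::<_, i64>(0))?
        .collect::<rusqlite::Result<Vec<i64>>>()?;
    assert_eq!(values, vec![1, 2, 3]);
    Ok(())
}
```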
diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs
index bcb1090a8d..f074a7066b 100644
--- a/crates/vector_store/src/db.rs
+++ b/crates/vector_store/src/db.rs
@@ -7,9 +7,10 @@ use anyhow::{anyhow, Result};
 use rusqlite::{
     params,
-    types::{FromSql, FromSqlResult, ValueRef},
-    Connection,
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
 };
+use sha1::{Digest, Sha1};

 use crate::IndexedFile;
@@ -32,7 +33,60 @@ pub struct DocumentRecord {
 pub struct FileRecord {
     pub id: usize,
     pub relative_path: String,
-    pub sha1: String,
+    pub sha1: FileSha1,
+}
+
+#[derive(Debug)]
+pub struct FileSha1(pub Vec<u8>);
+
+impl FileSha1 {
+    pub fn from_str(content: String) -> Self {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+        return FileSha1(sha1);
+    }
+
+    pub fn equals(&self, content: &String) -> bool {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+
+        let equal = self
+            .0
+            .clone()
+            .into_iter()
+            .zip(sha1)
+            .filter(|&(a, b)| a == b)
+            .count()
+            == self.0.len();
+
+        equal
+    }
+}
+
+impl ToSql for FileSha1 {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        return self.0.to_sql();
+    }
+}
+
+impl FromSql for FileSha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        Ok(FileSha1(
+            bytes
+                .into_iter()
+                .map(|val| val.to_owned())
+                .collect::<Vec<u8>>(),
+        ))
+    }
 }

 #[derive(Debug)]
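The `equals` above recomputes the digest and compares it byte-by-byte with `zip`/`filter`/`count`. Since both sides are fixed-length SHA-1 digests, plain slice equality does the same element-wise comparison. A standalone sketch of the equivalent check (not part of the diff, using the same `sha1` crate):

```rust
use sha1::{Digest, Sha1};

pub struct FileSha1(pub Vec<u8>);

impl FileSha1 {
    pub fn from_str(content: &str) -> Self {
        let mut hasher = Sha1::new();
        hasher.update(content);
        FileSha1(hasher.finalize().to_vec())
    }

    // Equivalent to the zip/filter/count version for two same-length digests:
    // hash the candidate content and compare the digests as slices.
    pub fn equals(&self, content: &str) -> bool {
        let mut hasher = Sha1::new();
        hasher.update(content);
        self.0.as_slice() == hasher.finalize().as_slice()
    }
}

fn main() {
    let sha = FileSha1::from_str("fn main() {}");
    assert!(sha.equals("fn main() {}"));
    assert!(!sha.equals("fn main() { }"));
}
```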
@@ -63,6 +117,8 @@ impl VectorDatabase {
     }

     fn initialize_database(&self) -> Result<()> {
+        rusqlite::vtab::array::load_module(&self.db)?;
+
         // This will create the database if it doesnt exist
         // Initialize Vector Databasing Tables
@@ -81,7 +137,7 @@ impl VectorDatabase {
             id INTEGER PRIMARY KEY AUTOINCREMENT,
             worktree_id INTEGER NOT NULL,
             relative_path VARCHAR NOT NULL,
-            sha1 NVARCHAR(40) NOT NULL,
+            sha1 BLOB NOT NULL,
             FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
@@ -102,30 +158,23 @@ impl VectorDatabase {
         Ok(())
     }

-    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
-    //     // Check if we have the project, if we do, return the ID
-    //     // If we do not have the project, insert the project and return the ID
-
-    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-    //     let projects_query = db.prepare(&format!(
-    //         "SELECT id FROM projects WHERE path = {}",
-    //         project_path.to_str().unwrap() // This is unsafe
-    //     ))?;
-
-    //     let project_id = db.last_insert_rowid();
-
-    //     return Ok(project_id as usize);
-    // }
-
-    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
-        let files_insert = self.db.execute(
-            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
-            params![indexed_file.path.to_str(), indexed_file.sha1],
+        log::info!("Inserting File!");
+        self.db.execute(
+            "
+            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
+            ",
+            params![worktree_id, indexed_file.path.to_str()],
+        )?;
+        self.db.execute(
+            "
+            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
+            ",
+            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
         )?;
-        let inserted_id = self.db.last_insert_rowid();
+        let file_id = self.db.last_insert_rowid();

         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
@@ -135,7 +184,7 @@ impl VectorDatabase {
         self.db.execute(
             "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
             params![
-                inserted_id,
+                file_id,
                 document.offset.to_string(),
                 document.name,
                 embedding_blob
@@ -147,25 +196,41 @@ impl VectorDatabase {
     }

     pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        // Check that the absolute path doesnt exist
+        let mut worktree_query = self
+            .db
+            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+
+        let worktree_id = worktree_query
+            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                Ok(row.get::<_, i64>(0)?)
+            })
+            .map_err(|err| anyhow!(err));
+
+        if worktree_id.is_ok() {
+            return worktree_id;
+        }
+
+        // If worktree_id is Err, insert new worktree
         self.db.execute(
             "
             INSERT into worktrees (absolute_path) VALUES (?1)
-            ON CONFLICT DO NOTHING
             ",
             params![worktree_root_path.to_string_lossy()],
         )?;
         Ok(self.db.last_insert_rowid())
     }

-    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
-        let mut statement = self
-            .db
-            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
-        let mut result = Vec::new();
-        for row in
-            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
-        {
-            result.push(row?);
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
+        let mut statement = self.db.prepare(
+            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
+        )?;
+        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
+        for row in statement.query_map(params![worktree_id], |row| {
+            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
+        })? {
+            let row = row?;
+            result.insert(row.0, row.1);
         }
         Ok(result)
     }
@@ -204,6 +269,53 @@ impl VectorDatabase {
         Ok(())
     }

+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+        let mut statement = self.db.prepare(
+            "
+            SELECT
+                documents.id, files.relative_path, documents.offset, documents.name
+            FROM
+                documents, files
+            WHERE
+                documents.file_id = files.id AND
+                documents.id in rarray(?)
+            ",
+        )?;
+
+        let result_iter = statement.query_map(
+            params![std::rc::Rc::new(
+                ids.iter()
+                    .copied()
+                    .map(|v| rusqlite::types::Value::from(v))
+                    .collect::<Vec<_>>()
+            )],
+            |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, String>(1)?.into(),
+                    row.get(2)?,
+                    row.get(3)?,
+                ))
+            },
+        )?;
+
+        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+        for row in result_iter {
+            let (id, path, offset, name) = row?;
+            values_by_id.insert(id, (path, offset, name));
+        }
+
+        let mut results = Vec::with_capacity(ids.len());
+        for id in ids {
+            let (path, offset, name) = values_by_id
+                .remove(id)
+                .ok_or(anyhow!("missing document id {}", id))?;
+            results.push((path, offset, name));
+        }
+
+        Ok(results)
+    }
+
     pub fn get_documents(&self) -> Result<Vec<DocumentRecord>> {
         let mut query_statement = self
             .db
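`get_documents_by_ids` deliberately re-orders rows after the query: SQLite gives no guarantee that the `rarray()` join returns rows in input order, and the caller passes ids already ranked by similarity. A sketch of that reordering idiom (hypothetical `reorder_by_ids` helper, not part of the diff):

```rust
use std::collections::HashMap;

// Rows come back from SQL in arbitrary order, so re-key them by id and then
// emit them in the caller's original (ranked) order. Returns None if any
// requested id is missing.
fn reorder_by_ids<T>(ids: &[i64], rows: Vec<(i64, T)>) -> Option<Vec<T>> {
    let mut by_id: HashMap<i64, T> = rows.into_iter().collect();
    ids.iter().map(|id| by_id.remove(id)).collect()
}

fn main() {
    let rows = vec![(2, "b"), (1, "a")];
    assert_eq!(reorder_by_ids(&[1, 2], rows), Some(vec!["a", "b"]));
}
```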
diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs
index f34316e950..7e4c29cef6 100644
--- a/crates/vector_store/src/vector_store.rs
+++ b/crates/vector_store/src/vector_store.rs
@@ -7,15 +7,14 @@ mod search;
 mod vector_store_tests;

 use anyhow::{anyhow, Result};
-use db::{VectorDatabase, VECTOR_DB_URL};
-use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
+use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
-use language::LanguageRegistry;
+use language::{Language, LanguageRegistry};
 use parsing::Document;
 use project::{Fs, Project};
-use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -45,7 +44,7 @@ pub fn init(
                 let project = workspace.read(cx).project().clone();
                 if project.read(cx).is_local() {
                     vector_store.update(cx, |store, cx| {
-                        store.add_project(project, cx);
+                        store.add_project(project, cx).detach();
                     });
                 }
             }
@@ -57,16 +56,10 @@
 #[derive(Debug)]
 pub struct IndexedFile {
     path: PathBuf,
-    sha1: String,
+    sha1: FileSha1,
     documents: Vec<Document>,
 }

-// struct SearchResult {
-//     path: PathBuf,
-//     offset: usize,
-//     name: String,
-//     distance: f32,
-// }

 struct VectorStore {
     fs: Arc<dyn Fs>,
     database_url: Arc<str>,
@@ -99,20 +92,10 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        language_registry: &Arc<LanguageRegistry>,
+        language: Arc<Language>,
         file_path: PathBuf,
         content: String,
     ) -> Result<IndexedFile> {
-        dbg!(&file_path, &content);
-
-        let language = language_registry
-            .language_for_file(&file_path, None)
-            .await?;
-
-        if language.name().as_ref() != "Rust" {
-            Err(anyhow!("unsupported language"))?;
-        }
-
         let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
         let outline_config = grammar
             .outline_config
@@ -156,9 +139,11 @@ impl VectorStore {
             document.embedding = embedding;
         }

+        let sha1 = FileSha1::from_str(content);
+
         return Ok(IndexedFile {
             path: file_path,
-            sha1: String::new(),
+            sha1,
             documents,
         });
     }
@@ -171,7 +156,13 @@ impl VectorStore {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
+            .map(|worktree| {
+                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
+                async move {
+                    scan_complete.await;
+                    log::info!("worktree scan completed");
+                }
+            })
             .collect::<Vec<_>>();

         let fs = self.fs.clone();
@@ -182,6 +173,13 @@
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;

+            // TODO: remove this after fixing the bug in scan_complete
+            cx.background()
+                .timer(std::time::Duration::from_secs(3))
+                .await;
+
+            let db = VectorDatabase::new(&database_url)?;
+
             let worktrees = project.read_with(&cx, |project, cx| {
                 project
                     .worktrees(cx)
@@ -189,37 +187,74 @@
                     .collect::<Vec<_>>()
             });

-            let db = VectorDatabase::new(&database_url)?;
             let worktree_root_paths = worktrees
                 .iter()
                 .map(|worktree| worktree.abs_path().clone())
                 .collect::<Vec<_>>();
-            let (db, file_hashes) = cx
+
+            // Here we query the worktree ids, and yet we dont have them elsewhere
+            // We likely want to clean up these datastructures
+            let (db, worktree_hashes, worktree_ids) = cx
                 .background()
                 .spawn(async move {
-                    let mut hashes = Vec::new();
+                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
+                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                     for worktree_root_path in worktree_root_paths {
                         let worktree_id =
                             db.find_or_create_worktree(worktree_root_path.as_ref())?;
-                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
+                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                     }
-                    anyhow::Ok((db, hashes))
+                    anyhow::Ok((db, hashes, worktree_ids))
                })
                .await?;

-            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
+            let (paths_tx, paths_rx) =
+                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
+            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();

             cx.background()
                 .spawn({
                     let fs = fs.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
+                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
+                            let file_hashes = &worktree_hashes[&worktree_id];
                             for file in worktree.files(false, 0) {
                                 let absolute_path = worktree.absolutize(&file.path);
-                                dbg!(&absolute_path);
-                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
-                                    dbg!(&content);
-                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    if language.name().as_ref() != "Rust" {
+                                        continue;
+                                    }
+
+                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                        log::info!("loaded file: {absolute_path:?}");
+
+                                        let path_buf = file.path.to_path_buf();
+                                        let already_stored = file_hashes
+                                            .get(&path_buf)
+                                            .map_or(false, |existing_hash| {
+                                                existing_hash.equals(&content)
+                                            });
+
+                                        if !already_stored {
+                                            log::info!(
+                                                "File Changed (Sending to Parse): {:?}",
+                                                &path_buf
+                                            );
+                                            paths_tx
+                                                .try_send((
+                                                    worktree_id,
+                                                    path_buf,
+                                                    content,
+                                                    language,
+                                                ))
+                                                .unwrap();
+                                        }
+                                    }
                                 }
                             }
                         }
                     }
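The scan above is the producer half of a fan-out pipeline: one task walks worktrees and pushes `(worktree_id, path, content, language)` tuples, several parser tasks consume them (next hunks), and a single task owns the database connection. A reduced sketch of the channel contract, assuming smol's re-exported `async-channel` (stages run sequentially here just to keep it self-contained; the real code spawns the workers in parallel):

```rust
use smol::channel;

fn main() {
    smol::block_on(async {
        let (paths_tx, paths_rx) = channel::unbounded::<String>();
        let (indexed_tx, indexed_rx) = channel::unbounded::<String>();

        // Producer: queue work, then drop the sender so receivers terminate
        // once the queue drains.
        for path in ["a.rs", "b.rs"] {
            paths_tx.try_send(path.to_string()).unwrap();
        }
        drop(paths_tx);

        // Worker: drain paths, forward "indexed" results.
        while let Ok(path) = paths_rx.recv().await {
            indexed_tx.try_send(format!("indexed {path}")).unwrap();
        }
        drop(indexed_tx);

        // Consumer: the single owner of the write side (the DB in the diff).
        while let Ok(row) = indexed_rx.recv().await {
            println!("{row}");
        }
    });
}
```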
@@ -230,8 +265,8 @@ impl VectorStore {
             let db_write_task = cx.background().spawn(
                 async move {
                     // Initialize Database, creates database and tables if not exists
-                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
-                        db.insert_file(indexed_file).log_err();
+                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
+                        db.insert_file(worktree_id, indexed_file).log_err();
                     }

                     // ALL OF THE BELOW IS FOR TESTING,
@@ -271,29 +306,29 @@ impl VectorStore {
                     .log_err(),
             );

-            let provider = DummyEmbeddings {};
-            // let provider = OpenAIEmbeddings { client };
-
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
                         scope.spawn(async {
                             let mut parser = Parser::new();
                             let mut cursor = QueryCursor::new();
-                            while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                            while let Ok((worktree_id, file_path, content, language)) =
+                                paths_rx.recv().await
                             {
                                 if let Some(indexed_file) = Self::index_file(
                                     &mut cursor,
                                     &mut parser,
-                                    &provider,
-                                    &language_registry,
+                                    embedding_provider.as_ref(),
+                                    language,
                                     file_path,
                                     content,
                                 )
                                 .await
                                 .log_err()
                                 {
-                                    indexed_files_tx.try_send(indexed_file).unwrap();
+                                    indexed_files_tx
+                                        .try_send((worktree_id, indexed_file))
+                                        .unwrap();
                                 }
                             }
                         });
@@ -315,41 +350,42 @@ impl VectorStore {
     ) -> Task<Result<Vec<SearchResult>>> {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
-        cx.spawn(|this, cx| async move {
+        cx.background().spawn(async move {
             let database = VectorDatabase::new(database_url.as_ref())?;

-            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
-            //
+            let phrase_embedding = embedding_provider
+                .embed_batch(vec![&phrase])
+                .await?
+                .into_iter()
+                .next()
+                .unwrap();
+
             let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
             database.for_each_document(0, |id, embedding| {
-                dbg!(id, &embedding);
-
-                let similarity = dot(&embedding.0, &embedding.0);
+                let similarity = dot(&embedding.0, &phrase_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
-                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                 }) {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };
-
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
             })?;

-            dbg!(&results);
-
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-            // let documents = database.get_documents_by_ids(ids)?;
+            let documents = database.get_documents_by_ids(&ids)?;

-            // let search_provider = cx
-            //     .background()
-            //     .spawn(async move { BruteForceSearch::load(&database) })
-            //     .await?;
-
-            // let results = search_provider.top_k_search(&embedding, limit))
-
-            anyhow::Ok(vec![])
+            anyhow::Ok(
+                documents
+                    .into_iter()
+                    .map(|(file_path, offset, name)| SearchResult {
+                        name,
+                        offset,
+                        file_path,
+                    })
+                    .collect(),
+            )
         })
     }
 }
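The comparator in `search` is intentionally flipped (`similarity.partial_cmp(&s)` rather than `s.partial_cmp(&similarity)`): that keeps `results` sorted by descending similarity, so `truncate(limit)` always discards the current worst match. A standalone sketch of this top-k idiom (hypothetical `insert_top_k` helper, not from the diff):

```rust
use std::cmp::Ordering;

// Keeps `results` sorted by descending similarity; truncate(limit) then
// always drops the lowest-scoring entry.
fn insert_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    let ix = match results
        .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
    {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}

fn main() {
    let mut results = Vec::new();
    for (id, sim) in [(1, 0.2_f32), (2, 0.9), (3, 0.5)] {
        insert_top_k(&mut results, id, sim, 2);
    }
    assert_eq!(results, vec![(2, 0.9), (3, 0.5)]);
}
```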
diff --git a/crates/vector_store/src/vector_store_tests.rs b/crates/vector_store/src/vector_store_tests.rs
index f3d01835e9..c67bb9954f 100644
--- a/crates/vector_store/src/vector_store_tests.rs
+++ b/crates/vector_store/src/vector_store_tests.rs
@@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     );
     languages.add(rust_language);

+    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
+    let db_path = db_dir.path().join("db.sqlite");
+
     let store = cx.add_model(|_| {
         VectorStore::new(
             fs.clone(),
-            "foo".to_string(),
+            db_path.to_string_lossy().to_string(),
             Arc::new(FakeEmbeddingProvider),
             languages,
         )
     });

     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
-    store
-        .update(cx, |store, cx| store.add_project(project, cx))
-        .await
-        .unwrap();
+    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+
+    // TODO - remove
+    cx.foreground()
+        .advance_clock(std::time::Duration::from_secs(3));
+
+    add_project.await.unwrap();

     let search_results = store
         .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
@@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
         .unwrap();

     assert_eq!(search_results[0].offset, 0);
-    assert_eq!(search_results[1].name, "aaa");
+    assert_eq!(search_results[0].name, "aaa");
 }

 #[test]
@@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         Ok(spans
             .iter()
             .map(|span| {
-                let mut result = vec![0.0; 26];
+                let mut result = vec![1.0; 26];
                 for letter in span.chars() {
-                    if letter as u32 > 'a' as u32 {
+                    let letter = letter.to_ascii_lowercase();
+                    if letter as u32 >= 'a' as u32 {
                         let ix = (letter as u32) - ('a' as u32);
                         if ix < 26 {
                             result[ix as usize] += 1.0;
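For context on the test changes: the fake provider embeds each span as a 26-dimension letter-frequency vector over a baseline of 1.0, so the query "aaaa" gets its largest dot product against the span named "aaa", which is consistent with the assertion moving to `search_results[0]`. A standalone sketch of that behaviour (hypothetical `fake_embedding`, mirroring the truncated hunk above):

```rust
// One dimension per letter a-z; each occurrence of a letter bumps its
// dimension, so spans sharing letters with the query score higher.
fn fake_embedding(span: &str) -> Vec<f32> {
    let mut result = vec![1.0; 26];
    for letter in span.chars() {
        let letter = letter.to_ascii_lowercase();
        if letter.is_ascii_lowercase() {
            let ix = (letter as u32 - 'a' as u32) as usize;
            result[ix] += 1.0;
        }
    }
    result
}

fn main() {
    let query = fake_embedding("aaaa");
    let a = fake_embedding("aaa");
    let b = fake_embedding("zzz");
    let dot = |x: &[f32], y: &[f32]| -> f32 { x.iter().zip(y).map(|(a, b)| a * b).sum() };
    assert!(dot(&query, &a) > dot(&query, &b));
}
```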