Updated database calls to share single connection, and simplified top_k_search sorting.

Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
KCaverly 2023-06-26 14:57:57 -04:00
parent 0f232e0ce2
commit 74b693d6b9
4 changed files with 148 additions and 124 deletions

View file

@ -1,4 +1,4 @@
use std::collections::HashMap; use std::{collections::HashMap, path::PathBuf};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
@ -46,31 +46,50 @@ impl FromSql for Embedding {
} }
} }
pub struct VectorDatabase {} pub struct VectorDatabase {
db: rusqlite::Connection,
}
impl VectorDatabase { impl VectorDatabase {
pub async fn initialize_database() -> Result<()> { pub fn new() -> Result<Self> {
let this = Self {
db: rusqlite::Connection::open(VECTOR_DB_URL)?,
};
this.initialize_database()?;
Ok(this)
}
fn initialize_database(&self) -> Result<()> {
// This will create the database if it doesnt exist // This will create the database if it doesnt exist
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
// Initialize Vector Databasing Tables // Initialize Vector Databasing Tables
db.execute( // self.db.execute(
// "
// CREATE TABLE IF NOT EXISTS projects (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// path NVARCHAR(100) NOT NULL
// )
// ",
// [],
// )?;
self.db.execute(
"CREATE TABLE IF NOT EXISTS files ( "CREATE TABLE IF NOT EXISTS files (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
path NVARCHAR(100) NOT NULL, path NVARCHAR(100) NOT NULL,
sha1 NVARCHAR(40) NOT NULL sha1 NVARCHAR(40) NOT NULL
)", )",
[], [],
)?; )?;
db.execute( self.db.execute(
"CREATE TABLE IF NOT EXISTS documents ( "CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL, file_id INTEGER NOT NULL,
offset INTEGER NOT NULL, offset INTEGER NOT NULL,
name NVARCHAR(100) NOT NULL, name NVARCHAR(100) NOT NULL,
embedding BLOB NOT NULL, embedding BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)", )",
[], [],
)?; )?;
@ -78,23 +97,37 @@ impl VectorDatabase {
Ok(()) Ok(())
} }
pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> { // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
// Write to files table, and return generated id. // // Check if we have the project, if we do, return the ID
let db = rusqlite::Connection::open(VECTOR_DB_URL)?; // // If we do not have the project, insert the project and return the ID
let files_insert = db.execute( // let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
// let projects_query = db.prepare(&format!(
// "SELECT id FROM projects WHERE path = {}",
// project_path.to_str().unwrap() // This is unsafe
// ))?;
// let project_id = db.last_insert_rowid();
// return Ok(project_id as usize);
// }
pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
// Write to files table, and return generated id.
let files_insert = self.db.execute(
"INSERT INTO files (path, sha1) VALUES (?1, ?2)", "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
params![indexed_file.path.to_str(), indexed_file.sha1], params![indexed_file.path.to_str(), indexed_file.sha1],
)?; )?;
let inserted_id = db.last_insert_rowid(); let inserted_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second // Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind. // I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents { for document in indexed_file.documents {
let embedding_blob = bincode::serialize(&document.embedding)?; let embedding_blob = bincode::serialize(&document.embedding)?;
db.execute( self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![ params![
inserted_id, inserted_id,
@ -109,70 +142,42 @@ impl VectorDatabase {
} }
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> { pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
let db = rusqlite::Connection::open(VECTOR_DB_URL)?; let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
let result_iter = query_statement.query_map([], |row| {
fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> { Ok(FileRecord {
let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?; id: row.get(0)?,
let result_iter = query_statement.query_map([], |row| { path: row.get(1)?,
Ok(FileRecord { sha1: row.get(2)?,
id: row.get(0)?, })
path: row.get(1)?, })?;
sha1: row.get(2)?,
})
})?;
let mut results = vec![];
for result in result_iter {
results.push(result?);
}
return Ok(results);
}
let mut pages: HashMap<usize, FileRecord> = HashMap::new(); let mut pages: HashMap<usize, FileRecord> = HashMap::new();
let result_iter = query(db); for result in result_iter {
if result_iter.is_ok() { let result = result?;
for result in result_iter.unwrap() { pages.insert(result.id, result);
pages.insert(result.id, result);
}
} }
return Ok(pages); Ok(pages)
} }
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> { pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
// Should return a HashMap in which the key is the id, and the value is the finished document let mut query_statement = self
.db
// Get Data from Database .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
let db = rusqlite::Connection::open(VECTOR_DB_URL)?; let result_iter = query_statement.query_map([], |row| {
Ok(DocumentRecord {
fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> { id: row.get(0)?,
let mut query_statement = file_id: row.get(1)?,
db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; offset: row.get(2)?,
let result_iter = query_statement.query_map([], |row| { name: row.get(3)?,
Ok(DocumentRecord { embedding: row.get(4)?,
id: row.get(0)?, })
file_id: row.get(1)?, })?;
offset: row.get(2)?,
name: row.get(3)?,
embedding: row.get(4)?,
})
})?;
let mut results = vec![];
for result in result_iter {
results.push(result?);
}
return Ok(results);
}
let mut documents: HashMap<usize, DocumentRecord> = HashMap::new(); let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
let result_iter = query(db); for result in result_iter {
if result_iter.is_ok() { let result = result?;
for result in result_iter.unwrap() { documents.insert(result.id, result);
documents.insert(result.id, result);
}
} }
return Ok(documents); return Ok(documents);

View file

@ -94,16 +94,6 @@ impl EmbeddingProvider for OpenAIEmbeddings {
response.usage.total_tokens response.usage.total_tokens
); );
// do we need to re-order these based on the `index` field?
eprintln!(
"indices: {:?}",
response
.data
.iter()
.map(|embedding| embedding.index)
.collect::<Vec<_>>()
);
Ok(response Ok(response
.data .data
.into_iter() .into_iter()

View file

@ -19,8 +19,8 @@ pub struct BruteForceSearch {
} }
impl BruteForceSearch { impl BruteForceSearch {
pub fn load() -> Result<Self> { pub fn load(db: &VectorDatabase) -> Result<Self> {
let db = VectorDatabase {}; // let db = VectorDatabase {};
let documents = db.get_documents()?; let documents = db.get_documents()?;
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
let mut document_ids = vec![]; let mut document_ids = vec![];
@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch {
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> { async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
let target = Array1::from_vec(vec.to_owned()); let target = Array1::from_vec(vec.to_owned());
let distances = self.candidate_array.dot(&target); let similarities = self.candidate_array.dot(&target);
let distances = distances.to_vec(); let similarities = similarities.to_vec();
// construct a tuple vector from the floats, the tuple being (index,float) // construct a tuple vector from the floats, the tuple being (index,float)
let mut with_indices = distances let mut with_indices = similarities
.clone() .iter()
.into_iter() .copied()
.enumerate() .enumerate()
.map(|(index, value)| (index, value)) .map(|(index, value)| (self.document_ids[index], value))
.collect::<Vec<(usize, f32)>>(); .collect::<Vec<(usize, f32)>>();
// sort the tuple vector by float // sort the tuple vector by float
with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) { with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
(true, true) => Ordering::Equal, with_indices.truncate(limit);
(true, false) => Ordering::Greater, with_indices
(false, true) => Ordering::Less,
(false, false) => a.1.partial_cmp(&b.1).unwrap(),
});
// extract the sorted indices from the sorted tuple vector // // extract the sorted indices from the sorted tuple vector
let stored_indices = with_indices // let stored_indices = with_indices
.into_iter() // .into_iter()
.map(|(index, value)| index) // .map(|(index, value)| index)
.collect::<Vec<usize>>(); // .collect::<Vec<>>();
let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect(); // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
let mut results = vec![]; // let mut results = vec![];
for idx in sorted_indices[0..limit].to_vec() { // for idx in sorted_indices[0..limit].to_vec() {
results.push((self.document_ids[idx], 1.0 - distances[idx])); // results.push((self.document_ids[idx], 1.0 - similarities[idx]));
} // }
return results; // return results;
} }
} }

View file

@ -1,5 +1,6 @@
mod db; mod db;
mod embedding; mod embedding;
mod parsing;
mod search; mod search;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
@ -7,11 +8,13 @@ use db::VectorDatabase;
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use gpui::{AppContext, Entity, ModelContext, ModelHandle};
use language::LanguageRegistry; use language::LanguageRegistry;
use parsing::Document;
use project::{Fs, Project}; use project::{Fs, Project};
use search::{BruteForceSearch, VectorSearch};
use smol::channel; use smol::channel;
use std::{path::PathBuf, sync::Arc, time::Instant}; use std::{path::PathBuf, sync::Arc, time::Instant};
use tree_sitter::{Parser, QueryCursor}; use tree_sitter::{Parser, QueryCursor};
use util::{http::HttpClient, ResultExt}; use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated; use workspace::WorkspaceCreated;
pub fn init( pub fn init(
@ -39,13 +42,6 @@ pub fn init(
.detach(); .detach();
} }
#[derive(Debug)]
pub struct Document {
pub offset: usize,
pub name: String,
pub embedding: Vec<f32>,
}
#[derive(Debug)] #[derive(Debug)]
pub struct IndexedFile { pub struct IndexedFile {
path: PathBuf, path: PathBuf,
@ -180,18 +176,54 @@ impl VectorStore {
.detach(); .detach();
cx.background() cx.background()
.spawn(async move { .spawn({
let client = client.clone();
async move {
// Initialize Database, creates database and tables if not exists // Initialize Database, creates database and tables if not exists
VectorDatabase::initialize_database().await.log_err(); let db = VectorDatabase::new()?;
while let Ok(indexed_file) = indexed_files_rx.recv().await { while let Ok(indexed_file) = indexed_files_rx.recv().await {
VectorDatabase::insert_file(indexed_file).await.log_err(); db.insert_file(indexed_file).log_err();
}
// ALL OF THE BELOW IS FOR TESTING,
// This should be removed as we find and appropriate place for evaluate our search.
let embedding_provider = OpenAIEmbeddings{ client };
let queries = vec![
"compute embeddings for all of the symbols in the codebase, and write them to a database",
"compute an outline view of all of the symbols in a buffer",
"scan a directory on the file system and load all of its children into an in-memory snapshot",
];
let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
let t2 = Instant::now();
let documents = db.get_documents().unwrap();
let files = db.get_files().unwrap();
println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
let t1 = Instant::now();
let mut bfs = BruteForceSearch::load(&db).unwrap();
println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
for (idx, embed) in embeddings.into_iter().enumerate() {
let t0 = Instant::now();
println!("\nQuery: {:?}", queries[idx]);
let results = bfs.top_k_search(&embed, 5).await;
println!("Search Elapsed: {}", t0.elapsed().as_millis());
for (id, distance) in results {
println!("");
println!(" distance: {:?}", distance);
println!(" document: {:?}", documents[&id].name);
println!(" path: {:?}", files[&documents[&id].file_id].path);
}
} }
anyhow::Ok(()) anyhow::Ok(())
}) }}.log_err())
.detach(); .detach();
let provider = DummyEmbeddings {}; let provider = DummyEmbeddings {};
// let provider = OpenAIEmbeddings { client };
cx.background() cx.background()
.scoped(|scope| { .scoped(|scope| {