Updated database calls to share single connection, and simplified top_k_search sorting.
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
0f232e0ce2
commit
74b693d6b9
4 changed files with 148 additions and 124 deletions
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, path::PathBuf};
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
|
||||||
|
@ -46,31 +46,50 @@ impl FromSql for Embedding {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct VectorDatabase {}
|
pub struct VectorDatabase {
|
||||||
|
db: rusqlite::Connection,
|
||||||
|
}
|
||||||
|
|
||||||
impl VectorDatabase {
|
impl VectorDatabase {
|
||||||
pub async fn initialize_database() -> Result<()> {
|
pub fn new() -> Result<Self> {
|
||||||
|
let this = Self {
|
||||||
|
db: rusqlite::Connection::open(VECTOR_DB_URL)?,
|
||||||
|
};
|
||||||
|
this.initialize_database()?;
|
||||||
|
Ok(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initialize_database(&self) -> Result<()> {
|
||||||
// This will create the database if it doesn't exist
|
// This will create the database if it doesn't exist
|
||||||
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
|
||||||
|
|
||||||
// Initialize Vector Databasing Tables
|
// Initialize Vector Databasing Tables
|
||||||
db.execute(
|
// self.db.execute(
|
||||||
|
// "
|
||||||
|
// CREATE TABLE IF NOT EXISTS projects (
|
||||||
|
// id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
// path NVARCHAR(100) NOT NULL
|
||||||
|
// )
|
||||||
|
// ",
|
||||||
|
// [],
|
||||||
|
// )?;
|
||||||
|
|
||||||
|
self.db.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS files (
|
"CREATE TABLE IF NOT EXISTS files (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
path NVARCHAR(100) NOT NULL,
|
path NVARCHAR(100) NOT NULL,
|
||||||
sha1 NVARCHAR(40) NOT NULL
|
sha1 NVARCHAR(40) NOT NULL
|
||||||
)",
|
)",
|
||||||
[],
|
[],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
db.execute(
|
self.db.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS documents (
|
"CREATE TABLE IF NOT EXISTS documents (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
file_id INTEGER NOT NULL,
|
file_id INTEGER NOT NULL,
|
||||||
offset INTEGER NOT NULL,
|
offset INTEGER NOT NULL,
|
||||||
name NVARCHAR(100) NOT NULL,
|
name NVARCHAR(100) NOT NULL,
|
||||||
embedding BLOB NOT NULL,
|
embedding BLOB NOT NULL,
|
||||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||||
)",
|
)",
|
||||||
[],
|
[],
|
||||||
)?;
|
)?;
|
||||||
|
@ -78,23 +97,37 @@ impl VectorDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
|
// pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
|
||||||
// Write to files table, and return generated id.
|
// // Check if we have the project, if we do, return the ID
|
||||||
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
// // If we do not have the project, insert the project and return the ID
|
||||||
|
|
||||||
let files_insert = db.execute(
|
// let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
||||||
|
|
||||||
|
// let projects_query = db.prepare(&format!(
|
||||||
|
// "SELECT id FROM projects WHERE path = {}",
|
||||||
|
// project_path.to_str().unwrap() // This is unsafe
|
||||||
|
// ))?;
|
||||||
|
|
||||||
|
// let project_id = db.last_insert_rowid();
|
||||||
|
|
||||||
|
// return Ok(project_id as usize);
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
|
||||||
|
// Write to files table, and return generated id.
|
||||||
|
let files_insert = self.db.execute(
|
||||||
"INSERT INTO files (path, sha1) VALUES (?1, ?2)",
|
"INSERT INTO files (path, sha1) VALUES (?1, ?2)",
|
||||||
params![indexed_file.path.to_str(), indexed_file.sha1],
|
params![indexed_file.path.to_str(), indexed_file.sha1],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let inserted_id = db.last_insert_rowid();
|
let inserted_id = self.db.last_insert_rowid();
|
||||||
|
|
||||||
// Currently inserting at approximately 3400 documents a second
|
// Currently inserting at approximately 3400 documents a second
|
||||||
// I imagine we can speed this up with a bulk insert of some kind.
|
// I imagine we can speed this up with a bulk insert of some kind.
|
||||||
for document in indexed_file.documents {
|
for document in indexed_file.documents {
|
||||||
let embedding_blob = bincode::serialize(&document.embedding)?;
|
let embedding_blob = bincode::serialize(&document.embedding)?;
|
||||||
|
|
||||||
db.execute(
|
self.db.execute(
|
||||||
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
|
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
|
||||||
params![
|
params![
|
||||||
inserted_id,
|
inserted_id,
|
||||||
|
@ -109,70 +142,42 @@ impl VectorDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
|
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
|
||||||
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
|
||||||
|
let result_iter = query_statement.query_map([], |row| {
|
||||||
fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
|
Ok(FileRecord {
|
||||||
let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
|
id: row.get(0)?,
|
||||||
let result_iter = query_statement.query_map([], |row| {
|
path: row.get(1)?,
|
||||||
Ok(FileRecord {
|
sha1: row.get(2)?,
|
||||||
id: row.get(0)?,
|
})
|
||||||
path: row.get(1)?,
|
})?;
|
||||||
sha1: row.get(2)?,
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut results = vec![];
|
|
||||||
for result in result_iter {
|
|
||||||
results.push(result?);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut pages: HashMap<usize, FileRecord> = HashMap::new();
|
let mut pages: HashMap<usize, FileRecord> = HashMap::new();
|
||||||
let result_iter = query(db);
|
for result in result_iter {
|
||||||
if result_iter.is_ok() {
|
let result = result?;
|
||||||
for result in result_iter.unwrap() {
|
pages.insert(result.id, result);
|
||||||
pages.insert(result.id, result);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(pages);
|
Ok(pages)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
|
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
|
||||||
// Should return a HashMap in which the key is the id, and the value is the finished document
|
let mut query_statement = self
|
||||||
|
.db
|
||||||
// Get Data from Database
|
.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
|
||||||
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
let result_iter = query_statement.query_map([], |row| {
|
||||||
|
Ok(DocumentRecord {
|
||||||
fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
|
id: row.get(0)?,
|
||||||
let mut query_statement =
|
file_id: row.get(1)?,
|
||||||
db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
|
offset: row.get(2)?,
|
||||||
let result_iter = query_statement.query_map([], |row| {
|
name: row.get(3)?,
|
||||||
Ok(DocumentRecord {
|
embedding: row.get(4)?,
|
||||||
id: row.get(0)?,
|
})
|
||||||
file_id: row.get(1)?,
|
})?;
|
||||||
offset: row.get(2)?,
|
|
||||||
name: row.get(3)?,
|
|
||||||
embedding: row.get(4)?,
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut results = vec![];
|
|
||||||
for result in result_iter {
|
|
||||||
results.push(result?);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
|
let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
|
||||||
let result_iter = query(db);
|
for result in result_iter {
|
||||||
if result_iter.is_ok() {
|
let result = result?;
|
||||||
for result in result_iter.unwrap() {
|
documents.insert(result.id, result);
|
||||||
documents.insert(result.id, result);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(documents);
|
return Ok(documents);
|
||||||
|
|
|
@ -94,16 +94,6 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
||||||
response.usage.total_tokens
|
response.usage.total_tokens
|
||||||
);
|
);
|
||||||
|
|
||||||
// do we need to re-order these based on the `index` field?
|
|
||||||
eprintln!(
|
|
||||||
"indices: {:?}",
|
|
||||||
response
|
|
||||||
.data
|
|
||||||
.iter()
|
|
||||||
.map(|embedding| embedding.index)
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(response
|
Ok(response
|
||||||
.data
|
.data
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
|
@ -19,8 +19,8 @@ pub struct BruteForceSearch {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BruteForceSearch {
|
impl BruteForceSearch {
|
||||||
pub fn load() -> Result<Self> {
|
pub fn load(db: &VectorDatabase) -> Result<Self> {
|
||||||
let db = VectorDatabase {};
|
// let db = VectorDatabase {};
|
||||||
let documents = db.get_documents()?;
|
let documents = db.get_documents()?;
|
||||||
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
|
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
|
||||||
let mut document_ids = vec![];
|
let mut document_ids = vec![];
|
||||||
|
@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch {
|
||||||
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
|
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
|
||||||
let target = Array1::from_vec(vec.to_owned());
|
let target = Array1::from_vec(vec.to_owned());
|
||||||
|
|
||||||
let distances = self.candidate_array.dot(&target);
|
let similarities = self.candidate_array.dot(&target);
|
||||||
|
|
||||||
let distances = distances.to_vec();
|
let similarities = similarities.to_vec();
|
||||||
|
|
||||||
// construct a tuple vector from the floats, the tuple being (index,float)
|
// construct a tuple vector from the floats, the tuple being (index,float)
|
||||||
let mut with_indices = distances
|
let mut with_indices = similarities
|
||||||
.clone()
|
.iter()
|
||||||
.into_iter()
|
.copied()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(index, value)| (index, value))
|
.map(|(index, value)| (self.document_ids[index], value))
|
||||||
.collect::<Vec<(usize, f32)>>();
|
.collect::<Vec<(usize, f32)>>();
|
||||||
|
|
||||||
// sort the tuple vector by float
|
// sort the tuple vector by float
|
||||||
with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
|
with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
|
||||||
(true, true) => Ordering::Equal,
|
with_indices.truncate(limit);
|
||||||
(true, false) => Ordering::Greater,
|
with_indices
|
||||||
(false, true) => Ordering::Less,
|
|
||||||
(false, false) => a.1.partial_cmp(&b.1).unwrap(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// extract the sorted indices from the sorted tuple vector
|
// // extract the sorted indices from the sorted tuple vector
|
||||||
let stored_indices = with_indices
|
// let stored_indices = with_indices
|
||||||
.into_iter()
|
// .into_iter()
|
||||||
.map(|(index, value)| index)
|
// .map(|(index, value)| index)
|
||||||
.collect::<Vec<usize>>();
|
// .collect::<Vec<>>();
|
||||||
|
|
||||||
let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
|
// let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
|
||||||
|
|
||||||
let mut results = vec![];
|
// let mut results = vec![];
|
||||||
for idx in sorted_indices[0..limit].to_vec() {
|
// for idx in sorted_indices[0..limit].to_vec() {
|
||||||
results.push((self.document_ids[idx], 1.0 - distances[idx]));
|
// results.push((self.document_ids[idx], 1.0 - similarities[idx]));
|
||||||
}
|
// }
|
||||||
|
|
||||||
return results;
|
// return results;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod db;
|
mod db;
|
||||||
mod embedding;
|
mod embedding;
|
||||||
|
mod parsing;
|
||||||
mod search;
|
mod search;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
@ -7,11 +8,13 @@ use db::VectorDatabase;
|
||||||
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
|
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
|
||||||
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
|
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
|
||||||
use language::LanguageRegistry;
|
use language::LanguageRegistry;
|
||||||
|
use parsing::Document;
|
||||||
use project::{Fs, Project};
|
use project::{Fs, Project};
|
||||||
|
use search::{BruteForceSearch, VectorSearch};
|
||||||
use smol::channel;
|
use smol::channel;
|
||||||
use std::{path::PathBuf, sync::Arc, time::Instant};
|
use std::{path::PathBuf, sync::Arc, time::Instant};
|
||||||
use tree_sitter::{Parser, QueryCursor};
|
use tree_sitter::{Parser, QueryCursor};
|
||||||
use util::{http::HttpClient, ResultExt};
|
use util::{http::HttpClient, ResultExt, TryFutureExt};
|
||||||
use workspace::WorkspaceCreated;
|
use workspace::WorkspaceCreated;
|
||||||
|
|
||||||
pub fn init(
|
pub fn init(
|
||||||
|
@ -39,13 +42,6 @@ pub fn init(
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Document {
|
|
||||||
pub offset: usize,
|
|
||||||
pub name: String,
|
|
||||||
pub embedding: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct IndexedFile {
|
pub struct IndexedFile {
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
|
@ -180,18 +176,54 @@ impl VectorStore {
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
cx.background()
|
cx.background()
|
||||||
.spawn(async move {
|
.spawn({
|
||||||
|
let client = client.clone();
|
||||||
|
async move {
|
||||||
// Initialize Database, creates database and tables if not exists
|
// Initialize Database, creates database and tables if not exists
|
||||||
VectorDatabase::initialize_database().await.log_err();
|
let db = VectorDatabase::new()?;
|
||||||
while let Ok(indexed_file) = indexed_files_rx.recv().await {
|
while let Ok(indexed_file) = indexed_files_rx.recv().await {
|
||||||
VectorDatabase::insert_file(indexed_file).await.log_err();
|
db.insert_file(indexed_file).log_err();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ALL OF THE BELOW IS FOR TESTING,
|
||||||
|
// This should be removed once we find an appropriate place to evaluate our search.
|
||||||
|
|
||||||
|
let embedding_provider = OpenAIEmbeddings{ client };
|
||||||
|
let queries = vec![
|
||||||
|
"compute embeddings for all of the symbols in the codebase, and write them to a database",
|
||||||
|
"compute an outline view of all of the symbols in a buffer",
|
||||||
|
"scan a directory on the file system and load all of its children into an in-memory snapshot",
|
||||||
|
];
|
||||||
|
let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
|
||||||
|
|
||||||
|
let t2 = Instant::now();
|
||||||
|
let documents = db.get_documents().unwrap();
|
||||||
|
let files = db.get_files().unwrap();
|
||||||
|
println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
|
||||||
|
|
||||||
|
let t1 = Instant::now();
|
||||||
|
let mut bfs = BruteForceSearch::load(&db).unwrap();
|
||||||
|
println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
|
||||||
|
for (idx, embed) in embeddings.into_iter().enumerate() {
|
||||||
|
let t0 = Instant::now();
|
||||||
|
println!("\nQuery: {:?}", queries[idx]);
|
||||||
|
let results = bfs.top_k_search(&embed, 5).await;
|
||||||
|
println!("Search Elapsed: {}", t0.elapsed().as_millis());
|
||||||
|
for (id, distance) in results {
|
||||||
|
println!("");
|
||||||
|
println!(" distance: {:?}", distance);
|
||||||
|
println!(" document: {:?}", documents[&id].name);
|
||||||
|
println!(" path: {:?}", files[&documents[&id].file_id].path);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::Ok(())
|
anyhow::Ok(())
|
||||||
})
|
}}.log_err())
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
let provider = DummyEmbeddings {};
|
let provider = DummyEmbeddings {};
|
||||||
|
// let provider = OpenAIEmbeddings { client };
|
||||||
|
|
||||||
cx.background()
|
cx.background()
|
||||||
.scoped(|scope| {
|
.scoped(|scope| {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue