added proper blob serialization for embeddings and vector search trait

This commit is contained in:
KCaverly 2023-06-25 20:02:56 -04:00
parent c071b271be
commit 65bbb7c57b
6 changed files with 104 additions and 39 deletions

13
Cargo.lock generated
View file

@ -1768,9 +1768,9 @@ dependencies = [
[[package]]
name = "cxx"
version = "1.0.94"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93"
checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8"
dependencies = [
"cc",
"cxxbridge-flags",
@ -1795,15 +1795,15 @@ dependencies = [
[[package]]
name = "cxxbridge-flags"
version = "1.0.94"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb"
checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8"
[[package]]
name = "cxxbridge-macro"
version = "1.0.94"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5"
checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d"
dependencies = [
"proc-macro2",
"quote",
@ -7913,6 +7913,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui",
"isahc",

View file

@ -17,7 +17,7 @@ util = { path = "../util" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
rusqlite = "0.27.0"
rusqlite = { version = "0.27.0", features=["blob"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
@ -25,6 +25,7 @@ lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View file

@ -1,13 +1,44 @@
use anyhow::Result;
use rusqlite::params;
use std::collections::HashMap;
use crate::IndexedFile;
use anyhow::{anyhow, Result};
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
Connection,
};
use util::ResultExt;
use crate::{Document, IndexedFile};
// This is saving to a local database store within the users dev zed path
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
const VECTOR_DB_URL: &str = "embeddings_db";
// Note: `DocumentRecord` mirrors a row of the `documents` table; it is not the
// in-memory `Document` type used by callers.
// Row shape produced by the SELECT in `get_documents`.
#[derive(Debug)]
pub struct DocumentRecord {
id: usize, // rowid of the document row
offset: usize, // offset of the document within its file — TODO confirm unit (bytes vs chars)
name: String,
embedding: Embedding, // stored as a bincode blob, decoded via the FromSql impl below
}
// Newtype over the raw embedding vector so we can implement rusqlite's
// `FromSql` (blob -> Vec<f32>) without an orphan-rule conflict.
#[derive(Debug)]
struct Embedding(Vec<f32>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
return Ok(Embedding(embedding.unwrap()));
}
}
pub struct VectorDatabase {}
impl VectorDatabase {
@ -51,37 +82,66 @@ impl VectorDatabase {
let inserted_id = db.last_insert_rowid();
// I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again
// I imagine there is a better way to serialize to/from blob
fn get_binary_from_values(values: Vec<f32>) -> String {
let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect();
bits.join(";")
}
fn get_values_from_binary(bin: &str) -> Vec<f32> {
(0..bin.len() / 32)
.map(|i| {
let start = i * 32;
let end = start + 32;
f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap())
})
.collect()
}
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
inserted_id,
document.offset.to_string(),
document.name,
get_binary_from_values(document.embedding)
embedding_blob
],
)?;
}
Ok(())
}
/// Load up to 10 document rows from the embeddings database and return them
/// keyed by their database id.
///
/// # Errors
///
/// Failures opening the database, preparing/running the query, or decoding a
/// row (including its embedding blob) are propagated to the caller — the
/// previous version silently returned an empty map on any query error.
pub fn get_documents(&self) -> Result<HashMap<usize, Document>> {
    let db = rusqlite::Connection::open(VECTOR_DB_URL)?;

    // TODO: `LIMIT 10` is an early-development placeholder; lift it or make it
    // a parameter once real consumers exist.
    let mut statement =
        db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?;
    let rows = statement.query_map([], |row| {
        Ok(DocumentRecord {
            id: row.get(0)?,
            offset: row.get(1)?,
            name: row.get(2)?,
            embedding: row.get(3)?,
        })
    })?;

    let mut documents: HashMap<usize, Document> = HashMap::new();
    for record in rows {
        // Per-row decode errors (bad blob, wrong types) abort the whole load.
        let record = record?;
        documents.insert(
            record.id,
            Document {
                offset: record.offset,
                name: record.name,
                // unwrap the Embedding newtype back to the raw Vec<f32>
                embedding: record.embedding.0,
            },
        );
    }
    Ok(documents)
}
}

View file

@ -13,6 +13,7 @@ lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
}
@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
// 1536 is the embedding dimension of OpenAI's text-embedding-ada-002,
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1024];
let dummy_vec = vec![0.32 as f32; 1536];
return Ok(vec![dummy_vec; spans.len()]);
}
}

View file

@ -0,0 +1,5 @@
trait VectorSearch {
    // Given a query vector, return (id, distance) tuples for the nearest
    // stored embeddings — presumably best match first; confirm ordering once
    // an implementation exists.
    // NOTE(review): despite the `top_k` name there is no `k`/limit parameter;
    // consider adding one, and prefer `&[f32]` over `&Vec<f32>` for the query.
    fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
}

View file

@ -1,5 +1,6 @@
mod db;
mod embedding;
mod search;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
@ -39,10 +40,10 @@ pub fn init(
}
#[derive(Debug)]
struct Document {
offset: usize,
name: String,
embedding: Vec<f32>,
pub struct Document {
pub offset: usize,
pub name: String,
pub embedding: Vec<f32>,
}
#[derive(Debug)]
@ -185,14 +186,13 @@ impl VectorStore {
while let Ok(indexed_file) = indexed_files_rx.recv().await {
VectorDatabase::insert_file(indexed_file).await.log_err();
}
anyhow::Ok(())
})
.detach();
// let provider = OpenAIEmbeddings { client };
let provider = DummyEmbeddings {};
let t0 = Instant::now();
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
@ -218,9 +218,6 @@ impl VectorStore {
}
})
.await;
let duration = t0.elapsed();
log::info!("indexed project in {duration:?}");
})
.detach();
}