Added proper blob serialization for embeddings, and a vector-search trait
This commit is contained in:
parent
c071b271be
commit
65bbb7c57b
6 changed files with 104 additions and 39 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -1768,9 +1768,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cxx"
|
||||
version = "1.0.94"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93"
|
||||
checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cxxbridge-flags",
|
||||
|
@ -1795,15 +1795,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cxxbridge-flags"
|
||||
version = "1.0.94"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb"
|
||||
checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8"
|
||||
|
||||
[[package]]
|
||||
name = "cxxbridge-macro"
|
||||
version = "1.0.94"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5"
|
||||
checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -7913,6 +7913,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"isahc",
|
||||
|
|
|
@ -17,7 +17,7 @@ util = { path = "../util" }
|
|||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
smol.workspace = true
|
||||
rusqlite = "0.27.0"
|
||||
rusqlite = { version = "0.27.0", features=["blob"] }
|
||||
isahc.workspace = true
|
||||
log.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
|
@ -25,6 +25,7 @@ lazy_static.workspace = true
|
|||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
async-trait.workspace = true
|
||||
bincode = "1.3.3"
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
|
|
|
@ -1,13 +1,44 @@
|
|||
use anyhow::Result;
|
||||
use rusqlite::params;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::IndexedFile;
|
||||
use anyhow::{anyhow, Result};
|
||||
|
||||
use rusqlite::{
|
||||
params,
|
||||
types::{FromSql, FromSqlResult, ValueRef},
|
||||
Connection,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{Document, IndexedFile};
|
||||
|
||||
// This saves to a local database store within the user's dev zed path.
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
|
||||
const VECTOR_DB_URL: &str = "embeddings_db";
|
||||
|
||||
// Note this is not an appropriate document:
// `DocumentRecord` is the raw row shape read back from the `documents`
// table; `get_documents` converts it into the crate-level `Document`,
// using `id` as the map key.
#[derive(Debug)]
pub struct DocumentRecord {
    // Row id of the `documents` row (used as the lookup key).
    id: usize,
    // Offset of the span within its source file — units (bytes vs chars)
    // not visible here; TODO confirm against the indexer.
    offset: usize,
    // Name of the indexed document/symbol.
    name: String,
    // Embedding vector, decoded from the BLOB column via `FromSql`.
    embedding: Embedding,
}
|
||||
|
||||
// Newtype over the raw embedding vector so a `FromSql` impl can be hung
// on it (decoding the bincode-encoded BLOB stored in the `embedding`
// column).
#[derive(Debug)]
struct Embedding(Vec<f32>);
|
||||
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
return Ok(Embedding(embedding.unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
// Zero-sized handle for the embeddings store; methods open their own
// SQLite connection to `VECTOR_DB_URL` (see `get_documents`).
pub struct VectorDatabase {}
|
||||
|
||||
impl VectorDatabase {
|
||||
|
@ -51,37 +82,66 @@ impl VectorDatabase {
|
|||
|
||||
let inserted_id = db.last_insert_rowid();
|
||||
|
||||
// I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again
|
||||
// I imagine there is a better way to serialize to/from blob
|
||||
fn get_binary_from_values(values: Vec<f32>) -> String {
|
||||
let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect();
|
||||
bits.join(";")
|
||||
}
|
||||
|
||||
fn get_values_from_binary(bin: &str) -> Vec<f32> {
|
||||
(0..bin.len() / 32)
|
||||
.map(|i| {
|
||||
let start = i * 32;
|
||||
let end = start + 32;
|
||||
f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Currently inserting at approximately 3400 documents a second
|
||||
// I imagine we can speed this up with a bulk insert of some kind.
|
||||
for document in indexed_file.documents {
|
||||
let embedding_blob = bincode::serialize(&document.embedding)?;
|
||||
|
||||
db.execute(
|
||||
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
|
||||
params![
|
||||
inserted_id,
|
||||
document.offset.to_string(),
|
||||
document.name,
|
||||
get_binary_from_values(document.embedding)
|
||||
embedding_blob
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_documents(&self) -> Result<HashMap<usize, Document>> {
|
||||
// Should return a HashMap in which the key is the id, and the value is the finished document
|
||||
|
||||
// Get Data from Database
|
||||
let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
|
||||
|
||||
fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
|
||||
let mut query_statement =
|
||||
db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?;
|
||||
let result_iter = query_statement.query_map([], |row| {
|
||||
Ok(DocumentRecord {
|
||||
id: row.get(0)?,
|
||||
offset: row.get(1)?,
|
||||
name: row.get(2)?,
|
||||
embedding: row.get(3)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut results = vec![];
|
||||
for result in result_iter {
|
||||
results.push(result?);
|
||||
}
|
||||
|
||||
return Ok(results);
|
||||
}
|
||||
|
||||
let mut documents: HashMap<usize, Document> = HashMap::new();
|
||||
let result_iter = query(db);
|
||||
if result_iter.is_ok() {
|
||||
for result in result_iter.unwrap() {
|
||||
documents.insert(
|
||||
result.id,
|
||||
Document {
|
||||
offset: result.offset,
|
||||
name: result.name,
|
||||
embedding: result.embedding.0,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(documents);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ lazy_static! {
|
|||
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings {
|
|||
/// Return a fixed dummy embedding for every span, for testing without
/// hitting the real embedding API.
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
    // 1536 is the embedding width of OpenAI's ada-002 model — the model
    // we will likely be starting with. (The earlier 1024-wide binding was
    // dead code: immediately shadowed and never read.)
    let dummy_vec = vec![0.32_f32; 1536];
    Ok(vec![dummy_vec; spans.len()])
}
|
||||
}
|
||||
|
|
5
crates/vector_store/src/search.rs
Normal file
5
crates/vector_store/src/search.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
/// Nearest-neighbor search over stored embedding vectors.
///
/// NOTE(review): the method comment mentions "a limit to return", but the
/// signature takes no `k`/limit argument — confirm whether one is
/// intended. Also, `&[f32]` would be more idiomatic than `&Vec<f32>`, but
/// changing it would alter the trait signature for implementors.
trait VectorSearch {
    // Given a query vector, and a limit to return
    // Return a vector of id, distance tuples.
    fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
}
|
|
@ -1,5 +1,6 @@
|
|||
mod db;
|
||||
mod embedding;
|
||||
mod search;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use db::VectorDatabase;
|
||||
|
@ -39,10 +40,10 @@ pub fn init(
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Document {
|
||||
offset: usize,
|
||||
name: String,
|
||||
embedding: Vec<f32>,
|
||||
// Fields are public so sibling modules (e.g. `db`) can construct
// `Document`s when reading rows back out of the database.
pub struct Document {
    // Offset of the span within its source file — units not visible here;
    // TODO confirm against the indexer.
    pub offset: usize,
    // Name of the indexed document/symbol.
    pub name: String,
    // Embedding vector as stored (bincode BLOB in the db layer).
    pub embedding: Vec<f32>,
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -185,14 +186,13 @@ impl VectorStore {
|
|||
while let Ok(indexed_file) = indexed_files_rx.recv().await {
|
||||
VectorDatabase::insert_file(indexed_file).await.log_err();
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
// let provider = OpenAIEmbeddings { client };
|
||||
let provider = DummyEmbeddings {};
|
||||
|
||||
let t0 = Instant::now();
|
||||
|
||||
cx.background()
|
||||
.scoped(|scope| {
|
||||
for _ in 0..cx.background().num_cpus() {
|
||||
|
@ -218,9 +218,6 @@ impl VectorStore {
|
|||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let duration = t0.elapsed();
|
||||
log::info!("indexed project in {duration:?}");
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue