Reify Embedding/Sha1 structs that can be (de)serialized from SQL

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-08-31 17:55:43 +02:00
parent c763e728d1
commit 3001a46f69
5 changed files with 180 additions and 138 deletions

View file

@ -1,13 +1,10 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
use rusqlite::params;
use std::{
cmp::Ordering,
collections::HashMap,
@ -27,34 +24,6 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
/// Newtype wrapper around a vector of `f32` components.
///
/// Exists so that a dedicated `FromSql` implementation can be attached to it;
/// the inner vector is reachable through the public `.0` field.
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
/// Newtype wrapper around a byte vector.
///
/// Presumably holds a SHA-1 digest (going by the name), but the type itself
/// does not enforce any particular length — confirm against the producer.
/// Exists so that a dedicated `FromSql` implementation can be attached to it.
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
impl FromSql for Embedding {
    /// Decodes an [`Embedding`] from a SQL blob column.
    ///
    /// The column value must be a blob containing a bincode-serialized
    /// `Vec<f32>`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ValueRef::as_blob` when the value is not a
    /// blob, and returns `FromSqlError::Other` wrapping the bincode error when
    /// the bytes cannot be deserialized.
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // A closure (rather than passing the variant constructor directly) is
        // needed so the `Box<bincode::ErrorKind>` can coerce to the boxed
        // trait object that `FromSqlError::Other` holds.
        let components: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(components))
    }
}
impl FromSql for Sha1 {
    /// Decodes a [`Sha1`] from a SQL blob column.
    ///
    /// The column value must be a blob containing a bincode-serialized
    /// `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ValueRef::as_blob` when the value is not a
    /// blob, and returns `FromSqlError::Other` wrapping the bincode error when
    /// the bytes cannot be deserialized.
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // A closure (rather than passing the variant constructor directly) is
        // needed so the `Box<bincode::ErrorKind>` can coerce to the boxed
        // trait object that `FromSqlError::Other` holds.
        let digest: Vec<u8> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Sha1(digest))
    }
}
#[derive(Clone)]
pub struct VectorDatabase {
path: Arc<Path>,
@ -255,9 +224,6 @@ impl VectorDatabase {
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
let sha_blob = bincode::serialize(&document.sha1)?;
db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
@ -265,8 +231,8 @@ impl VectorDatabase {
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob,
sha_blob
document.embedding,
document.sha1
],
)?;
}
@ -351,7 +317,7 @@ impl VectorDatabase {
pub fn top_k_search(
&self,
query_embedding: &Vec<f32>,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
@ -360,7 +326,7 @@ impl VectorDatabase {
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_document(db, &file_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
@ -417,7 +383,7 @@ impl VectorDatabase {
fn for_each_document(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Vec<f32>),
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
@ -435,7 +401,7 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding.0));
.for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
/// Computes the dot product of two equally sized `f32` slices.
///
/// The product is evaluated as a `1 x k` by `k x 1` matrix multiplication via
/// `matrixmultiply::sgemm`, where `k` is the slice length.
///
/// # Panics
///
/// Panics if `vec_a` and `vec_b` have different lengths.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let dim = vec_a.len();
    assert_eq!(dim, vec_b.len());

    // Strides are passed to `sgemm` as element counts of type `isize`.
    let stride = dim as isize;
    let mut product = 0.0;

    // SAFETY: both input pointers come from live slices that were just checked
    // to hold exactly `dim` elements, and the output pointer refers to the
    // local `product`. With m = n = 1 and k = `dim`, `sgemm` (per its
    // documented C <- alpha*A*B + beta*C contract) addresses `dim` consecutive
    // elements of each input and a single output element.
    unsafe {
        matrixmultiply::sgemm(
            1,      // m: rows of A and C
            dim,    // k: columns of A, rows of B
            1,      // n: columns of B and C
            1.0,    // alpha
            vec_a.as_ptr(),
            stride, // row stride of A
            1,      // column stride of A
            vec_b.as_ptr(),
            1,      // row stride of B
            stride, // column stride of B
            0.0,    // beta
            &mut product as *mut f32,
            1,      // row stride of C
            1,      // column stride of C
        );
    }

    product
}