Reify Embedding/Sha1 structs that can be (de)serialized from SQL
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
parent
c763e728d1
commit
3001a46f69
5 changed files with 180 additions and 138 deletions
|
@ -1,13 +1,10 @@
|
|||
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
|
||||
use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use project::{search::PathMatcher, Fs};
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::{
|
||||
params,
|
||||
types::{FromSql, FromSqlResult, ValueRef},
|
||||
};
|
||||
use rusqlite::params;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::HashMap,
|
||||
|
@ -27,34 +24,6 @@ pub struct FileRecord {
|
|||
pub mtime: Timestamp,
|
||||
}
|
||||
|
||||
/// Newtype wrapper around a `Vec<f32>`.
///
/// The `FromSql` implementation in this file builds it by bincode-deserializing
/// the bytes of a BLOB column.
#[derive(Debug)]
|
||||
struct Embedding(pub Vec<f32>);
|
||||
|
||||
/// Newtype wrapper around a `Vec<u8>`.
///
/// The `FromSql` implementation in this file builds it by bincode-deserializing
/// the bytes of a BLOB column.
#[derive(Debug)]
|
||||
struct Sha1(pub Vec<u8>);
|
||||
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
return Ok(Embedding(embedding.unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
impl FromSql for Sha1 {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if sha1.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
|
||||
}
|
||||
return Ok(Sha1(sha1.unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VectorDatabase {
|
||||
path: Arc<Path>,
|
||||
|
@ -255,9 +224,6 @@ impl VectorDatabase {
|
|||
// Currently inserting at approximately 3400 documents a second
|
||||
// I imagine we can speed this up with a bulk insert of some kind.
|
||||
for document in documents {
|
||||
let embedding_blob = bincode::serialize(&document.embedding)?;
|
||||
let sha_blob = bincode::serialize(&document.sha1)?;
|
||||
|
||||
db.execute(
|
||||
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![
|
||||
|
@ -265,8 +231,8 @@ impl VectorDatabase {
|
|||
document.range.start.to_string(),
|
||||
document.range.end.to_string(),
|
||||
document.name,
|
||||
embedding_blob,
|
||||
sha_blob
|
||||
document.embedding,
|
||||
document.sha1
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
@ -351,7 +317,7 @@ impl VectorDatabase {
|
|||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
query_embedding: &Vec<f32>,
|
||||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
|
||||
|
@ -360,7 +326,7 @@ impl VectorDatabase {
|
|||
self.transact(move |db| {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
Self::for_each_document(db, &file_ids, |id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let similarity = embedding.similarity(&query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
|
@ -417,7 +383,7 @@ impl VectorDatabase {
|
|||
fn for_each_document(
|
||||
db: &rusqlite::Connection,
|
||||
file_ids: &[i64],
|
||||
mut f: impl FnMut(i64, Vec<f32>),
|
||||
mut f: impl FnMut(i64, Embedding),
|
||||
) -> Result<()> {
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
|
@ -435,7 +401,7 @@ impl VectorDatabase {
|
|||
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.for_each(|(id, embedding)| f(id, embedding.0));
|
||||
.for_each(|(id, embedding)| f(id, embedding));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
|
|||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
|
||||
let len = vec_a.len();
|
||||
assert_eq!(len, vec_b.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
vec_a.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
vec_b.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue