updated embedding database calls to maintain project consistency

Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
KCaverly 2023-06-28 16:25:05 -04:00
parent 400d39740c
commit 85e71415fe
3 changed files with 0 additions and 111 deletions

View file

@ -236,27 +236,6 @@ impl VectorDatabase {
Ok(result)
}
/// Reads every row of the `files` table and returns the records keyed by
/// their `id` column.
///
/// # Errors
///
/// Fails if the statement cannot be prepared or executed, or if any row
/// cannot be converted into a `FileRecord`; the first failure encountered
/// is returned.
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
    let mut statement = self
        .db
        .prepare("SELECT id, relative_path, sha1 FROM files")?;

    // Column positions follow the SELECT list above.
    let rows = statement.query_map([], |row| {
        let id = row.get(0)?;
        let relative_path = row.get(1)?;
        let sha1 = row.get(2)?;
        Ok(FileRecord {
            id,
            relative_path,
            sha1,
        })
    })?;

    // Collecting into a `Result` stops at the first failing row, and a
    // repeated id overwrites the earlier entry, exactly as repeated
    // `insert` calls would.
    let files_by_id = rows
        .map(|row| row.map(|file| (file.id, file)))
        .collect::<std::result::Result<HashMap<usize, FileRecord>, _>>()?;

    Ok(files_by_id)
}
pub fn for_each_document(
&self,
worktree_ids: &[i64],
@ -321,29 +300,6 @@ impl VectorDatabase {
Ok(results)
}
/// Reads every row of the `documents` table and returns the records keyed
/// by their `id` column.
///
/// # Errors
///
/// Fails if the statement cannot be prepared or executed, or if any row
/// cannot be converted into a `DocumentRecord`; the first failure
/// encountered is returned.
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
    let mut statement = self
        .db
        .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;

    // Column positions follow the SELECT list above.
    let rows = statement.query_map([], |row| {
        let id = row.get(0)?;
        let file_id = row.get(1)?;
        let offset = row.get(2)?;
        let name = row.get(3)?;
        let embedding = row.get(4)?;
        Ok(DocumentRecord {
            id,
            file_id,
            offset,
            name,
            embedding,
        })
    })?;

    // Collecting into a `Result` stops at the first failing row, and a
    // repeated id overwrites the earlier entry, exactly as repeated
    // `insert` calls would.
    let documents_by_id = rows
        .map(|row| row.map(|document| (document.id, document)))
        .collect::<std::result::Result<HashMap<usize, DocumentRecord>, _>>()?;

    Ok(documents_by_id)
}
}
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {

View file

@ -1,66 +0,0 @@
use std::{cmp::Ordering, path::PathBuf};
use async_trait::async_trait;
use ndarray::{Array1, Array2};
use crate::db::{DocumentRecord, VectorDatabase};
use anyhow::Result;
// A search index that can be queried with an embedding vector.
#[async_trait]
pub trait VectorSearch {
// Given a query vector `vec` and a maximum number of results `limit`,
// return a vector of (id, score) tuples.
//
// NOTE(review): this comment originally described the second tuple element
// as a "distance", but the `BruteForceSearch` implementation in this file
// returns dot products sorted from highest to lowest. Confirm which meaning
// implementors and callers are expected to rely on.
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
}
// Search index that keeps every document embedding in a single in-memory
// matrix and scores a query against all of them.
pub struct BruteForceSearch {
// Document ids; entry `i` is the id reported for row `i` of
// `candidate_array` when search results are built.
document_ids: Vec<usize>,
// One row per loaded document; `load` allocates it with 1536 columns.
candidate_array: ndarray::Array2<f32>,
}
impl BruteForceSearch {
    /// Builds the in-memory index from every document stored in `db`.
    ///
    /// Each document becomes one row of a `(document count, 1536)` matrix,
    /// and its id is stored at the same position in `document_ids`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `VectorDatabase::get_documents`.
    ///
    /// # Panics
    ///
    /// Indexing `embedding.0[j]` panics if a stored embedding has fewer
    /// than 1536 elements.
    pub fn load(db: &VectorDatabase) -> Result<Self> {
        let documents = db.get_documents()?;

        // `keys()` and `values()` walk the same unmodified map, so position
        // `i` in both vectors refers to the same document.
        let document_ids: Vec<usize> = documents.keys().copied().collect();
        let records: Vec<&DocumentRecord> = documents.values().collect();

        let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
        let rows = candidate_array.axis_iter_mut(ndarray::Axis(0));
        for (record, mut row) in records.iter().zip(rows) {
            for (column, cell) in row.iter_mut().enumerate() {
                *cell = record.embedding.0[column];
            }
        }

        Ok(BruteForceSearch {
            document_ids,
            candidate_array,
        })
    }
}
#[async_trait]
impl VectorSearch for BruteForceSearch {
    /// Scores `vec` against every stored embedding with a dot product and
    /// returns at most `limit` `(document id, score)` pairs, highest score
    /// first.
    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
        let query = Array1::from_vec(vec.clone());
        let scores = self.candidate_array.dot(&query);

        // Pair each row's score with the id stored at the same position.
        let mut ranked: Vec<(usize, f32)> = scores
            .iter()
            .enumerate()
            .map(|(row, &score)| (self.document_ids[row], score))
            .collect();

        // Descending by score; pairs that cannot be compared (NaN) are
        // treated as equal, and the stable sort keeps their relative order.
        ranked.sort_by(|lhs, rhs| match rhs.1.partial_cmp(&lhs.1) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        });
        ranked.truncate(limit);
        ranked
    }
}

View file

@ -1,7 +1,6 @@
mod db;
mod embedding;
mod modal;
mod search;
#[cfg(test)]
mod vector_store_tests;