added brute force search and VectorSearch trait

This commit is contained in:
KCaverly 2023-06-26 10:34:12 -04:00
parent 65bbb7c57b
commit 7937a16002
3 changed files with 122 additions and 2 deletions

View file

@ -26,6 +26,7 @@ serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
ndarray = "0.15.6"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View file

@ -1,5 +1,85 @@
trait VectorSearch {
use std::cmp::Ordering;
use async_trait::async_trait;
use ndarray::{Array1, Array2};
use crate::db::{DocumentRecord, VectorDatabase};
use anyhow::Result;
#[async_trait]
pub trait VectorSearch {
// Given a query vector, and a limit to return
// Return a vector of id, distance tuples.
fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
}
pub struct BruteForceSearch {
document_ids: Vec<usize>,
candidate_array: ndarray::Array2<f32>,
}
impl BruteForceSearch {
pub fn load() -> Result<Self> {
let db = VectorDatabase {};
let documents = db.get_documents()?;
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
let mut document_ids = vec![];
for i in documents.keys() {
document_ids.push(i.to_owned());
}
let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
for (j, col) in row.iter_mut().enumerate() {
*col = embeddings[i].embedding.0[j];
}
}
return Ok(BruteForceSearch {
document_ids,
candidate_array,
});
}
}
#[async_trait]
impl VectorSearch for BruteForceSearch {
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
let target = Array1::from_vec(vec.to_owned());
let distances = self.candidate_array.dot(&target);
let distances = distances.to_vec();
// construct a tuple vector from the floats, the tuple being (index,float)
let mut with_indices = distances
.clone()
.into_iter()
.enumerate()
.map(|(index, value)| (index, value))
.collect::<Vec<(usize, f32)>>();
// sort the tuple vector by float
with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
(true, true) => Ordering::Equal,
(true, false) => Ordering::Greater,
(false, true) => Ordering::Less,
(false, false) => a.1.partial_cmp(&b.1).unwrap(),
});
// extract the sorted indices from the sorted tuple vector
let stored_indices = with_indices
.into_iter()
.map(|(index, value)| index)
.collect::<Vec<usize>>();
let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
let mut results = vec![];
for idx in sorted_indices[0..limit].to_vec() {
results.push((self.document_ids[idx], 1.0 - distances[idx]));
}
return results;
}
}