updated embedding database calls to maintain project consistency
Co-authored-by: maxbrunsfeld <max@zed.dev>
This commit is contained in:
parent
400d39740c
commit
85e71415fe
3 changed files with 0 additions and 111 deletions
|
@ -236,27 +236,6 @@ impl VectorDatabase {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
|
||||
let mut query_statement = self
|
||||
.db
|
||||
.prepare("SELECT id, relative_path, sha1 FROM files")?;
|
||||
let result_iter = query_statement.query_map([], |row| {
|
||||
Ok(FileRecord {
|
||||
id: row.get(0)?,
|
||||
relative_path: row.get(1)?,
|
||||
sha1: row.get(2)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut pages: HashMap<usize, FileRecord> = HashMap::new();
|
||||
for result in result_iter {
|
||||
let result = result?;
|
||||
pages.insert(result.id, result);
|
||||
}
|
||||
|
||||
Ok(pages)
|
||||
}
|
||||
|
||||
pub fn for_each_document(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
|
@ -321,29 +300,6 @@ impl VectorDatabase {
|
|||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
|
||||
let mut query_statement = self
|
||||
.db
|
||||
.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
|
||||
let result_iter = query_statement.query_map([], |row| {
|
||||
Ok(DocumentRecord {
|
||||
id: row.get(0)?,
|
||||
file_id: row.get(1)?,
|
||||
offset: row.get(2)?,
|
||||
name: row.get(3)?,
|
||||
embedding: row.get(4)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
|
||||
for result in result_iter {
|
||||
let result = result?;
|
||||
documents.insert(result.id, result);
|
||||
}
|
||||
|
||||
return Ok(documents);
|
||||
}
|
||||
}
|
||||
|
||||
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
use std::{cmp::Ordering, path::PathBuf};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ndarray::{Array1, Array2};
|
||||
|
||||
use crate::db::{DocumentRecord, VectorDatabase};
|
||||
use anyhow::Result;
|
||||
|
||||
#[async_trait]
|
||||
pub trait VectorSearch {
|
||||
// Given a query vector, and a limit to return
|
||||
// Return a vector of id, distance tuples.
|
||||
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
|
||||
}
|
||||
|
||||
pub struct BruteForceSearch {
|
||||
document_ids: Vec<usize>,
|
||||
candidate_array: ndarray::Array2<f32>,
|
||||
}
|
||||
|
||||
impl BruteForceSearch {
|
||||
pub fn load(db: &VectorDatabase) -> Result<Self> {
|
||||
let documents = db.get_documents()?;
|
||||
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
|
||||
let mut document_ids = vec![];
|
||||
for i in documents.keys() {
|
||||
document_ids.push(i.to_owned());
|
||||
}
|
||||
|
||||
let mut candidate_array = Array2::<f32>::default((documents.len(), 1536));
|
||||
for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() {
|
||||
for (j, col) in row.iter_mut().enumerate() {
|
||||
*col = embeddings[i].embedding.0[j];
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(BruteForceSearch {
|
||||
document_ids,
|
||||
candidate_array,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VectorSearch for BruteForceSearch {
|
||||
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
|
||||
let target = Array1::from_vec(vec.to_owned());
|
||||
|
||||
let similarities = self.candidate_array.dot(&target);
|
||||
|
||||
let similarities = similarities.to_vec();
|
||||
|
||||
// construct a tuple vector from the floats, the tuple being (index,float)
|
||||
let mut with_indices = similarities
|
||||
.iter()
|
||||
.copied()
|
||||
.enumerate()
|
||||
.map(|(index, value)| (self.document_ids[index], value))
|
||||
.collect::<Vec<(usize, f32)>>();
|
||||
|
||||
// sort the tuple vector by float
|
||||
with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
|
||||
with_indices.truncate(limit);
|
||||
with_indices
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
mod db;
|
||||
mod embedding;
|
||||
mod modal;
|
||||
mod search;
|
||||
|
||||
#[cfg(test)]
|
||||
mod vector_store_tests;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue