WIP: Got the streaming matrix multiplication working, and started work on file hashing.

Co-authored-by: maxbrunsfeld <max@zed.dev>
KCaverly 2023-06-26 19:01:19 -04:00
parent 74b693d6b9
commit 953e928bdb
7 changed files with 396 additions and 97 deletions
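Note on the "file hashing" half of this commit (a sketch, not code from the diff): the changes below add a `sha-1 = "0.10.1"` dependency and store a `sha1` column per file, so the hash presumably comes from digesting file contents along these lines. The helper name and hex encoding here are assumptions.

use sha1::{Digest, Sha1};

/// Hypothetical helper: hex-encoded SHA-1 of a file's contents, matching the
/// `sha1` column stored in the files table below.
fn content_sha1(content: &str) -> String {
    let mut hasher = Sha1::new();
    hasher.update(content.as_bytes());
    hasher
        .finalize()
        .iter()
        .map(|byte| format!("{:02x}", byte))
        .collect()
}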

Cargo.lock (generated)

@@ -7958,13 +7958,18 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "matrixmultiply",
  "ndarray",
  "project",
+ "rand 0.8.5",
  "rusqlite",
  "serde",
  "serde_json",
+ "sha-1 0.10.1",
  "smol",
  "tree-sitter",
+ "tree-sitter-rust",
+ "unindent",
  "util",
  "workspace",
 ]

vector_store/Cargo.toml

@@ -27,9 +27,14 @@ serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
 ndarray = "0.15.6"
+sha-1 = "0.10.1"
+matrixmultiply = "0.3.7"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
+tree-sitter-rust = "*"
+rand.workspace = true
+unindent.workspace = true

vector_store/src/db.rs

@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+    collections::HashMap,
+    path::{Path, PathBuf},
+};
 
 use anyhow::{anyhow, Result};
@@ -13,7 +16,7 @@ use crate::IndexedFile;
 // This is saving to a local database store within the users dev zed path
 // Where do we want this to sit?
 // Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";
 
 // Note this is not an appropriate document
 #[derive(Debug)]
@ -28,7 +31,7 @@ pub struct DocumentRecord {
#[derive(Debug)] #[derive(Debug)]
pub struct FileRecord { pub struct FileRecord {
pub id: usize, pub id: usize,
pub path: String, pub relative_path: String,
pub sha1: String, pub sha1: String,
} }
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
 }
 
 impl VectorDatabase {
-    pub fn new() -> Result<Self> {
+    pub fn new(path: &str) -> Result<Self> {
         let this = Self {
-            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+            db: rusqlite::Connection::open(path)?,
         };
         this.initialize_database()?;
         Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
         // This will create the database if it doesnt exist
         // Initialize Vector Databasing Tables
-        // self.db.execute(
-        //     "
-        //     CREATE TABLE IF NOT EXISTS projects (
-        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
-        //         path NVARCHAR(100) NOT NULL
-        //     )
-        //     ",
-        //     [],
-        // )?;
+        self.db.execute(
+            "CREATE TABLE IF NOT EXISTS worktrees (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                absolute_path VARCHAR NOT NULL
+            );
+            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+            ",
+            [],
+        )?;
 
         self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                path NVARCHAR(100) NOT NULL,
-                sha1 NVARCHAR(40) NOT NULL
+                worktree_id INTEGER NOT NULL,
+                relative_path VARCHAR NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL,
+                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
        )?;
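A note on the worktrees statement above (an observation, not part of the diff): rusqlite's `Connection::execute` prepares a single statement, so packing the `CREATE TABLE` and the `CREATE UNIQUE INDEX` into one string would normally go through `execute_batch` instead. A minimal sketch, assuming the same schema:

use rusqlite::Connection;

/// Sketch: `execute_batch` accepts multiple semicolon-separated statements,
/// unlike `execute`, which runs exactly one.
fn initialize_worktrees_table(db: &Connection) -> rusqlite::Result<()> {
    db.execute_batch(
        "CREATE TABLE IF NOT EXISTS worktrees (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            absolute_path VARCHAR NOT NULL
        );
        CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path
            ON worktrees (absolute_path);",
    )
}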
@@ -87,7 +92,7 @@ impl VectorDatabase {
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 file_id INTEGER NOT NULL,
                 offset INTEGER NOT NULL,
-                name NVARCHAR(100) NOT NULL,
+                name VARCHAR NOT NULL,
                 embedding BLOB NOT NULL,
                 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
@@ -116,7 +121,7 @@ impl VectorDatabase {
     pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
         let files_insert = self.db.execute(
-            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
+            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
@@ -141,12 +146,38 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        self.db.execute(
+            "
+            INSERT into worktrees (absolute_path) VALUES (?1)
+            ON CONFLICT DO NOTHING
+            ",
+            params![worktree_root_path.to_string_lossy()],
+        )?;
+        Ok(self.db.last_insert_rowid())
+    }
+
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
+        let mut statement = self
+            .db
+            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
+        let mut result = Vec::new();
+        for row in
+            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
+        {
+            result.push(row?);
+        }
+        Ok(result)
+    }
+
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, relative_path, sha1 FROM files")?;
         let result_iter = query_statement.query_map([], |row| {
             Ok(FileRecord {
                 id: row.get(0)?,
-                path: row.get(1)?,
+                relative_path: row.get(1)?,
                 sha1: row.get(2)?,
             })
         })?;
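One caveat in `find_or_create_worktree` above (reviewer note, not part of the diff): when `INSERT ... ON CONFLICT DO NOTHING` hits an existing row, SQLite does not update `last_insert_rowid()`, so the returned id can be stale. A find-or-create usually falls back to a SELECT; a sketch under that assumption:

use rusqlite::{params, Connection};
use std::path::Path;

/// Sketch: insert-or-ignore, then always SELECT the id, so the returned
/// rowid is correct even when the worktree row already existed.
fn find_or_create_worktree_id(db: &Connection, root: &Path) -> rusqlite::Result<i64> {
    db.execute(
        "INSERT INTO worktrees (absolute_path) VALUES (?1) ON CONFLICT DO NOTHING",
        params![root.to_string_lossy()],
    )?;
    db.query_row(
        "SELECT id FROM worktrees WHERE absolute_path = ?1",
        params![root.to_string_lossy()],
        |row| row.get(0),
    )
}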
@@ -160,6 +191,19 @@ impl VectorDatabase {
         Ok(pages)
     }
 
+    pub fn for_each_document(
+        &self,
+        worktree_id: i64,
+        mut f: impl FnMut(i64, Embedding),
+    ) -> Result<()> {
+        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        query_statement
+            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .filter_map(|row| row.ok())
+            .for_each(|row| f(row.0, row.1));
+        Ok(())
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db
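`for_each_document` above accepts a `worktree_id` but does not use it yet; the query scans every document. Scoping it to one worktree would mean joining through `files`. A hypothetical filtered version (the join is an assumption based on the schema above, and the raw `Vec<u8>` blob stands in for the crate's `Embedding` type):

use rusqlite::{params, Connection};

/// Sketch: stream (document id, embedding blob) pairs for a single worktree
/// by joining documents -> files on file_id and filtering on worktree_id.
fn for_each_document_in_worktree(
    db: &Connection,
    worktree_id: i64,
    mut f: impl FnMut(i64, Vec<u8>),
) -> rusqlite::Result<()> {
    let mut statement = db.prepare(
        "SELECT documents.id, documents.embedding
         FROM documents
         JOIN files ON documents.file_id = files.id
         WHERE files.worktree_id = ?1",
    )?;
    for row in statement.query_map(params![worktree_id], |row| {
        Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
    })? {
        let (id, embedding) = row?;
        f(id, embedding);
    }
    Ok(())
}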

vector_store/src/embedding.rs

@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
 }
 
 #[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 }

vector_store/src/search.rs

@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};
 
 use async_trait::async_trait;
 use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
 
 impl BruteForceSearch {
     pub fn load(db: &VectorDatabase) -> Result<Self> {
-        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
         with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
         with_indices.truncate(limit);
         with_indices
-        // // extract the sorted indices from the sorted tuple vector
-        // let stored_indices = with_indices
-        //     .into_iter()
-        //     .map(|(index, value)| index)
-        //     .collect::<Vec<>>();
-        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-        // let mut results = vec![];
-        // for idx in sorted_indices[0..limit].to_vec() {
-        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
-        // }
-        // return results;
     }
 }

vector_store/src/vector_store.rs

@@ -3,16 +3,19 @@ mod embedding;
 mod parsing;
 mod search;
 
+#[cfg(test)]
+mod vector_store_tests;
+
 use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 use language::LanguageRegistry;
 use parsing::Document;
 use project::{Fs, Project};
 use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+    let vector_store = cx.add_model(|cx| {
+        VectorStore::new(
+            fs,
+            VECTOR_DB_URL.to_string(),
+            Arc::new(OpenAIEmbeddings {
+                client: http_client,
+            }),
+            language_registry,
+        )
+    });
 
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
     documents: Vec<Document>,
 }
 
-struct SearchResult {
-    path: PathBuf,
-    offset: usize,
-    name: String,
-    distance: f32,
-}
+// struct SearchResult {
+//     path: PathBuf,
+//     offset: usize,
+//     name: String,
+//     distance: f32,
+// }
 
 struct VectorStore {
     fs: Arc<dyn Fs>,
-    http_client: Arc<dyn HttpClient>,
+    database_url: Arc<str>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
 }
 
+pub struct SearchResult {
+    pub name: String,
+    pub offset: usize,
+    pub file_path: PathBuf,
+}
+
 impl VectorStore {
     fn new(
         fs: Arc<dyn Fs>,
-        http_client: Arc<dyn HttpClient>,
+        database_url: String,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
     ) -> Self {
         Self {
             fs,
-            http_client,
+            database_url: database_url.into(),
+            embedding_provider,
             language_registry,
         }
     }
@@ -79,10 +99,12 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
+        content: String,
     ) -> Result<IndexedFile> {
+        dbg!(&file_path, &content);
         let language = language_registry
             .language_for_file(&file_path, None)
             .await?;
@@ -97,7 +119,6 @@ impl VectorStore {
             .as_ref()
             .ok_or_else(|| anyhow!("no outline query"))?;
 
-        let content = fs.load(&file_path).await?;
         parser.set_language(grammar.ts_language).unwrap();
         let tree = parser
             .parse(&content, None)
@@ -142,7 +163,11 @@ impl VectorStore {
         });
     }
 
-    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+    fn add_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -151,7 +176,8 @@ impl VectorStore {
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
-        let client = self.http_client.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
 
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
+            let db = VectorDatabase::new(&database_url)?;
+            let worktree_root_paths = worktrees
+                .iter()
+                .map(|worktree| worktree.abs_path().clone())
+                .collect::<Vec<_>>();
+            let (db, file_hashes) = cx
+                .background()
+                .spawn(async move {
+                    let mut hashes = Vec::new();
+                    for worktree_root_path in worktree_root_paths {
+                        let worktree_id =
+                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
+                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                    }
+                    anyhow::Ok((db, hashes))
+                })
+                .await?;
+
+            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
             let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
             cx.background()
-                .spawn(async move {
-                    for worktree in worktrees {
-                        for file in worktree.files(false, 0) {
-                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
+                .spawn({
+                    let fs = fs.clone();
+                    async move {
+                        for worktree in worktrees.into_iter() {
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+                                dbg!(&absolute_path);
+                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                    dbg!(&content);
+                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+                                }
+                            }
                         }
                     }
                 })
                 .detach();
 
-            cx.background()
-                .spawn({
-                    let client = client.clone();
-                    async move {
+            let db_write_task = cx.background().spawn(
+                async move {
                     // Initialize Database, creates database and tables if not exists
-                    let db = VectorDatabase::new()?;
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
                         db.insert_file(indexed_file).log_err();
                     }
@@ -188,39 +237,39 @@ impl VectorStore {
 
                     // ALL OF THE BELOW IS FOR TESTING,
                    // This should be removed as we find and appropriate place for evaluate our search.
-                    let embedding_provider = OpenAIEmbeddings{ client };
-                    let queries = vec![
-                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
-                        "compute an outline view of all of the symbols in a buffer",
-                        "scan a directory on the file system and load all of its children into an in-memory snapshot",
-                    ];
-                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
-
-                    let t2 = Instant::now();
-                    let documents = db.get_documents().unwrap();
-                    let files = db.get_files().unwrap();
-                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
-
-                    let t1 = Instant::now();
-                    let mut bfs = BruteForceSearch::load(&db).unwrap();
-                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
-                    for (idx, embed) in embeddings.into_iter().enumerate() {
-                        let t0 = Instant::now();
-                        println!("\nQuery: {:?}", queries[idx]);
-                        let results = bfs.top_k_search(&embed, 5).await;
-                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
-                        for (id, distance) in results {
-                            println!("");
-                            println!(" distance: {:?}", distance);
-                            println!(" document: {:?}", documents[&id].name);
-                            println!(" path: {:?}", files[&documents[&id].file_id].path);
-                        }
-                    }
+                    // let queries = vec![
+                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                    //     "compute an outline view of all of the symbols in a buffer",
+                    //     "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                    // ];
+                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+
+                    // let t2 = Instant::now();
+                    // let documents = db.get_documents().unwrap();
+                    // let files = db.get_files().unwrap();
+                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+                    // let t1 = Instant::now();
+                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
+                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                    // for (idx, embed) in embeddings.into_iter().enumerate() {
+                    //     let t0 = Instant::now();
+                    //     println!("\nQuery: {:?}", queries[idx]);
+                    //     let results = bfs.top_k_search(&embed, 5).await;
+                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                    //     for (id, distance) in results {
+                    //         println!("");
+                    //         println!(" distance: {:?}", distance);
+                    //         println!(" document: {:?}", documents[&id].name);
+                    //         println!(" path: {:?}", files[&documents[&id].file_id].relative_path);
+                    //     }
+                    // }
 
                     anyhow::Ok(())
-                }}.log_err())
-                .detach();
+                }
+                .log_err(),
+            );
 
             let provider = DummyEmbeddings {};
             // let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@ impl VectorStore {
                     scope.spawn(async {
                         let mut parser = Parser::new();
                         let mut cursor = QueryCursor::new();
-                        while let Ok(file_path) = paths_rx.recv().await {
+                        while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                        {
                             if let Some(indexed_file) = Self::index_file(
                                 &mut cursor,
                                 &mut parser,
                                 &provider,
-                                &fs,
                                 &language_registry,
                                 file_path,
+                                content,
                             )
                             .await
                             .log_err()
@@ -250,11 +300,86 @@ impl VectorStore {
                     }
                 })
                 .await;
+
+            drop(indexed_files_tx);
+            db_write_task.await;
+
+            anyhow::Ok(())
+        })
+    }
+
+    pub fn search(
+        &mut self,
+        phrase: String,
+        limit: usize,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
+        cx.spawn(|this, cx| async move {
+            let database = VectorDatabase::new(database_url.as_ref())?;
+
+            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
+            //
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            database.for_each_document(0, |id, embedding| {
+                dbg!(id, &embedding);
+                let similarity = dot(&embedding.0, &embedding.0);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+            dbg!(&results);
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+
+            // let documents = database.get_documents_by_ids(ids)?;
+            // let search_provider = cx
+            //     .background()
+            //     .spawn(async move { BruteForceSearch::load(&database) })
+            //     .await?;
+            // let results = search_provider.top_k_search(&embedding, limit))
+            anyhow::Ok(vec![])
         })
-        .detach();
     }
 }
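A note on the running top-k in `search` above (reviewer observation on still-WIP code: the similarity is currently the embedding dotted with itself, since the query-embedding line is commented out): the comparator keeps `results` sorted ascending, so `truncate(limit)` retains the lowest scores. For a highest-first top-k the comparator flips; a sketch:

use std::cmp::Ordering;

/// Sketch: keep `results` sorted by descending similarity so that
/// `truncate(limit)` drops the weakest matches instead of the strongest.
fn insert_top_k(results: &mut Vec<(i64, f32)>, id: i64, similarity: f32, limit: usize) {
    let ix = results
        .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
        .unwrap_or_else(|ix| ix);
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}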
 impl Entity for VectorStore {
     type Event = ();
 }
+
+fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}
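The `dot` function above appears to be the "streaming matrix multiplication" from the commit message: `matrixmultiply::sgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc)` is called with m = 1, k = len, n = 1, so `vec_a` is treated as a 1 x len row vector, `vec_b` as a len x 1 column vector, and the 1 x 1 product written to `result` is exactly their dot product. An equivalent scalar reference (the same check the tests below perform):

/// Reference dot product; `dot` above should agree with this up to
/// floating-point rounding.
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}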

vector_store/src/vector_store_tests.rs (new file)

@@ -0,0 +1,136 @@
use std::sync::Arc;

use crate::{dot, embedding::EmbeddingProvider, VectorStore};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry};
use project::{FakeFs, Project};
use rand::Rng;
use serde_json::json;
use unindent::Unindent;

#[gpui::test]
async fn test_vector_store(cx: &mut TestAppContext) {
    let fs = FakeFs::new(cx.background());
    fs.insert_tree(
        "/the-root",
        json!({
            "src": {
                "file1.rs": "
                    fn aaa() {
                        println!(\"aaaa!\");
                    }
                    fn zzzzzzzzz() {
                        println!(\"SLEEPING\");
                    }
                ".unindent(),
                "file2.rs": "
                    fn bbb() {
                        println!(\"bbbb!\");
                    }
                ".unindent(),
            }
        }),
    )
    .await;

    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
    let rust_language = Arc::new(
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".into()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_outline_query(
            r#"
            (function_item
                name: (identifier) @name
                body: (block)) @item
            "#,
        )
        .unwrap(),
    );
    languages.add(rust_language);

    let store = cx.add_model(|_| {
        VectorStore::new(
            fs.clone(),
            "foo".to_string(),
            Arc::new(FakeEmbeddingProvider),
            languages,
        )
    });

    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
    store
        .update(cx, |store, cx| store.add_project(project, cx))
        .await
        .unwrap();

    let search_results = store
        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
        .await
        .unwrap();

    assert_eq!(search_results[0].offset, 0);
    assert_eq!(search_results[1].name, "aaa");
}

#[test]
fn test_dot_product() {
    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);

    for _ in 0..100 {
        let mut rng = rand::thread_rng();
        let a: [f32; 32] = rng.gen();
        let b: [f32; 32] = rng.gen();

        assert_eq!(
            round_to_decimals(dot(&a, &b), 3),
            round_to_decimals(reference_dot(&a, &b), 3)
        );
    }

    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
        let factor = (10.0 as f32).powi(decimal_places);
        (n * factor).round() / factor
    }

    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
    }
}

struct FakeEmbeddingProvider;

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        Ok(spans
            .iter()
            .map(|span| {
                let mut result = vec![0.0; 26];
                for letter in span.chars() {
                    if letter as u32 > 'a' as u32 {
                        let ix = (letter as u32) - ('a' as u32);
                        if ix < 26 {
                            result[ix as usize] += 1.0;
                        }
                    }
                }

                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
                for x in &mut result {
                    *x /= norm;
                }

                result
            })
            .collect())
    }
}
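Two small details in `FakeEmbeddingProvider` above (observations about the test code, not changes to it): the strict `>` comparison skips 'a' itself, so frequency bucket 0 never increments, and a span with no lowercase letters would normalize by a zero norm, producing NaNs. An inclusive, zero-safe variant as a sketch:

/// Sketch: letter-frequency embedding that counts 'a' and guards the
/// normalization against an all-zero vector.
fn letter_frequencies(span: &str) -> Vec<f32> {
    let mut result = vec![0.0_f32; 26];
    for letter in span.chars() {
        if letter.is_ascii_lowercase() {
            result[(letter as u32 - 'a' as u32) as usize] += 1.0;
        }
    }
    let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in &mut result {
            *x /= norm;
        }
    }
    result
}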