From c071b271be195b0e8af9335469c969e6f1624d6d Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 23 Jun 2023 10:25:12 -0400 Subject: [PATCH] removed tokio and sqlx dependency, added dummy embeddings provider to save on open ai costs when testing --- Cargo.lock | 2 - crates/vector_store/Cargo.toml | 2 - crates/vector_store/src/db.rs | 74 ++++++++++--------------- crates/vector_store/src/embedding.rs | 12 ++++ crates/vector_store/src/vector_store.rs | 9 +-- 5 files changed, 45 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a93ce77af..3f13c75dda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7912,7 +7912,6 @@ name = "vector_store" version = "0.1.0" dependencies = [ "anyhow", - "async-compat", "async-trait", "futures 0.3.28", "gpui", @@ -7925,7 +7924,6 @@ dependencies = [ "serde", "serde_json", "smol", - "sqlx", "tree-sitter", "util", "workspace", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 2db672ed25..434f341147 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,8 +17,6 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true -sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] } -async-compat = "0.2.1" rusqlite = "0.27.0" isahc.workspace = true log.workspace = true diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index d335d327b8..e2b23f7548 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,6 +1,5 @@ use anyhow::Result; -use async_compat::{Compat, CompatExt}; -use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool}; +use rusqlite::params; use crate::IndexedFile; @@ -13,32 +12,20 @@ pub struct VectorDatabase {} impl VectorDatabase { pub async fn initialize_database() -> Result<()> { - // If database doesnt exist create database - if !Sqlite::database_exists(VECTOR_DB_URL) - .compat() - .await - .unwrap_or(false) - { - Sqlite::create_database(VECTOR_DB_URL).compat().await?; - } - - let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + // This will create the database if it doesnt exist + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; // Initialize Vector Databasing Tables - // We may be able to skip this assuming the database is never created - // without creating the tables at the same time. - sqlx::query( + db.execute( "CREATE TABLE IF NOT EXISTS files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - path NVARCHAR(100) NOT NULL, - sha1 NVARCHAR(40) NOT NULL - )", - ) - .execute(&db) - .compat() - .await?; + id INTEGER PRIMARY KEY AUTOINCREMENT, + path NVARCHAR(100) NOT NULL, + sha1 NVARCHAR(40) NOT NULL + )", + [], + )?; - sqlx::query( + db.execute( "CREATE TABLE IF NOT EXISTS documents ( id INTEGER PRIMARY KEY AUTOINCREMENT, file_id INTEGER NOT NULL, @@ -47,26 +34,22 @@ impl VectorDatabase { embedding BLOB NOT NULL, FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE )", - ) - .execute(&db) - .compat() - .await?; + [], + )?; Ok(()) } pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> { // Write to files table, and return generated id. - let db = SqlitePool::connect(VECTOR_DB_URL).compat().await?; + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; - let files_insert = sqlx::query("INSERT INTO files (path, sha1) VALUES ($1, $2)") - .bind(indexed_file.path.to_str()) - .bind(indexed_file.sha1) - .execute(&db) - .compat() - .await?; + let files_insert = db.execute( + "INSERT INTO files (path, sha1) VALUES (?1, ?2)", + params![indexed_file.path.to_str(), indexed_file.sha1], + )?; - let inserted_id = files_insert.last_insert_rowid(); + let inserted_id = db.last_insert_rowid(); // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again // I imagine there is a better way to serialize to/from blob @@ -88,16 +71,15 @@ impl VectorDatabase { // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in indexed_file.documents { - sqlx::query( - "INSERT INTO documents (file_id, offset, name, embedding) VALUES ($1, $2, $3, $4)", - ) - .bind(inserted_id) - .bind(document.offset.to_string()) - .bind(document.name) - .bind(get_binary_from_values(document.embedding)) - .execute(&db) - .compat() - .await?; + db.execute( + "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", + params![ + inserted_id, + document.offset.to_string(), + document.name, + get_binary_from_values(document.embedding) + ], + )?; } Ok(()) diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index f1ae5479ee..4883917d5a 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -47,6 +47,18 @@ pub trait EmbeddingProvider: Sync { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; } +pub struct DummyEmbeddings {} + +#[async_trait] +impl EmbeddingProvider for DummyEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { + // 1024 is the OpenAI Embeddings size for ada models. + // the model we will likely be starting with. + let dummy_vec = vec![0.32 as f32; 1024]; + return Ok(vec![dummy_vec; spans.len()]); + } +} + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index f4d5baca80..f424346d56 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -3,7 +3,7 @@ mod embedding; use anyhow::{anyhow, Result}; use db::VectorDatabase; -use embedding::{EmbeddingProvider, OpenAIEmbeddings}; +use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; use project::{Fs, Project}; @@ -38,14 +38,14 @@ pub fn init( .detach(); } -#[derive(Debug, sqlx::FromRow)] +#[derive(Debug)] struct Document { offset: usize, name: String, embedding: Vec, } -#[derive(Debug, sqlx::FromRow)] +#[derive(Debug)] pub struct IndexedFile { path: PathBuf, sha1: String, @@ -188,7 +188,8 @@ impl VectorStore { }) .detach(); - let provider = OpenAIEmbeddings { client }; + // let provider = OpenAIEmbeddings { client }; + let provider = DummyEmbeddings {}; let t0 = Instant::now();