move embedding provider to ai crate

This commit is contained in:
KCaverly 2023-09-22 09:33:59 -04:00
parent 48e151495f
commit 68c37ca2a4
11 changed files with 78 additions and 35 deletions

19
Cargo.lock generated
View file

@ -91,13 +91,25 @@ name = "ai"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"ctor", "async-trait",
"bincode",
"futures 0.3.28", "futures 0.3.28",
"gpui", "gpui",
"isahc", "isahc",
"lazy_static",
"log",
"matrixmultiply",
"ordered-float",
"parking_lot 0.11.2",
"parse_duration",
"postage",
"rand 0.8.5",
"regex", "regex",
"rusqlite",
"serde", "serde",
"serde_json", "serde_json",
"tiktoken-rs 0.5.4",
"util",
] ]
[[package]] [[package]]
@ -6725,9 +6737,9 @@ dependencies = [
name = "semantic_index" name = "semantic_index"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"ai",
"anyhow", "anyhow",
"async-trait", "async-trait",
"bincode",
"client", "client",
"collections", "collections",
"ctor", "ctor",
@ -6736,15 +6748,12 @@ dependencies = [
"futures 0.3.28", "futures 0.3.28",
"globset", "globset",
"gpui", "gpui",
"isahc",
"language", "language",
"lazy_static", "lazy_static",
"log", "log",
"matrixmultiply",
"node_runtime", "node_runtime",
"ordered-float", "ordered-float",
"parking_lot 0.11.2", "parking_lot 0.11.2",
"parse_duration",
"picker", "picker",
"postage", "postage",
"pretty_assertions", "pretty_assertions",

View file

@ -10,12 +10,25 @@ doctest = false
[dependencies] [dependencies]
gpui = { path = "../gpui" } gpui = { path = "../gpui" }
util = { path = "../util" }
async-trait.workspace = true
anyhow.workspace = true anyhow.workspace = true
futures.workspace = true futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true isahc.workspace = true
regex.workspace = true regex.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies] [dev-dependencies]
ctor.workspace = true gpui = { path = "../gpui", features = ["test-support"] }

View file

@ -1 +1,2 @@
pub mod completion; pub mod completion;
pub mod embedding;

View file

@ -27,8 +27,30 @@ lazy_static! {
} }
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>); pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality.
// Unfortunately, it has to live in the same crate that defines the `Embedding` struct.
// Keeping it here, though, introduces a `rusqlite` dependency into the `ai` crate,
// which is less than ideal.
// Reads an `Embedding` out of a SQLite column value.
//
// The column is read as a blob and decoded with bincode into a `Vec<f32>`,
// which is then wrapped in `Embedding`.
impl FromSql for Embedding {
    // Returns the decoded embedding, or:
    // - whatever error `ValueRef::as_blob` reports when the value cannot be
    //   read as a blob, or
    // - `FromSqlError::Other` wrapping the bincode error when decoding fails.
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // A single exhaustive match handles both outcomes, so there is no
        // `is_err()` check followed by `unwrap()` / `unwrap_err()`.
        match bincode::deserialize::<Vec<f32>>(bytes) {
            Ok(values) => Ok(Embedding(values)),
            Err(err) => Err(rusqlite::types::FromSqlError::Other(err)),
        }
    }
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding { impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self { fn from(value: Vec<f32>) -> Self {
Embedding(value) Embedding(value)
@ -63,24 +85,24 @@ impl Embedding {
} }
} }
impl FromSql for Embedding { // impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> { // fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?; // let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes); // let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() { // if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); // return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
} // }
Ok(Embedding(embedding.unwrap())) // Ok(Embedding(embedding.unwrap()))
} // }
} // }
impl ToSql for Embedding { // impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> { // fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0) // let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; // .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) // Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
} // }
} // }
#[derive(Clone)] #[derive(Clone)]
pub struct OpenAIEmbeddings { pub struct OpenAIEmbeddings {

View file

@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
doctest = false doctest = false
[dependencies] [dependencies]
ai = { path = "../ai" }
collections = { path = "../collections" } collections = { path = "../collections" }
gpui = { path = "../gpui" } gpui = { path = "../gpui" }
language = { path = "../language" } language = { path = "../language" }
@ -26,22 +27,18 @@ futures.workspace = true
ordered-float.workspace = true ordered-float.workspace = true
smol.workspace = true smol.workspace = true
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] } rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
isahc.workspace = true
log.workspace = true log.workspace = true
tree-sitter.workspace = true tree-sitter.workspace = true
lazy_static.workspace = true lazy_static.workspace = true
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
async-trait.workspace = true async-trait.workspace = true
bincode = "1.3.3"
matrixmultiply = "0.3.7"
tiktoken-rs = "0.5.0" tiktoken-rs = "0.5.0"
parking_lot.workspace = true parking_lot.workspace = true
rand.workspace = true rand.workspace = true
schemars.workspace = true schemars.workspace = true
globset.workspace = true globset.workspace = true
sha1 = "0.10.5" sha1 = "0.10.5"
parse_duration = "2.1.1"
[dev-dependencies] [dev-dependencies]
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }

View file

@ -1,10 +1,10 @@
use ai::embedding::OpenAIEmbeddings;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use client::{self, UserStore}; use client::{self, UserStore};
use gpui::{AsyncAppContext, ModelHandle, Task}; use gpui::{AsyncAppContext, ModelHandle, Task};
use language::LanguageRegistry; use language::LanguageRegistry;
use node_runtime::RealNodeRuntime; use node_runtime::RealNodeRuntime;
use project::{Project, RealFs}; use project::{Project, RealFs};
use semantic_index::embedding::OpenAIEmbeddings;
use semantic_index::semantic_index_settings::SemanticIndexSettings; use semantic_index::semantic_index_settings::SemanticIndexSettings;
use semantic_index::{SearchResult, SemanticIndex}; use semantic_index::{SearchResult, SemanticIndex};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View file

@ -1,8 +1,8 @@
use crate::{ use crate::{
embedding::Embedding,
parsing::{Span, SpanDigest}, parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION, SEMANTIC_INDEX_VERSION,
}; };
use ai::embedding::Embedding;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use collections::HashMap; use collections::HashMap;
use futures::channel::oneshot; use futures::channel::oneshot;

View file

@ -1,4 +1,5 @@
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle}; use crate::{parsing::Span, JobHandle};
use ai::embedding::EmbeddingProvider;
use gpui::executor::Background; use gpui::executor::Background;
use parking_lot::Mutex; use parking_lot::Mutex;
use smol::channel; use smol::channel;

View file

@ -1,4 +1,4 @@
use crate::embedding::{Embedding, EmbeddingProvider}; use ai::embedding::{Embedding, EmbeddingProvider};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use language::{Grammar, Language}; use language::{Grammar, Language};
use rusqlite::{ use rusqlite::{

View file

@ -1,5 +1,5 @@
mod db; mod db;
pub mod embedding; // pub mod embedding;
mod embedding_queue; mod embedding_queue;
mod parsing; mod parsing;
pub mod semantic_index_settings; pub mod semantic_index_settings;
@ -11,7 +11,7 @@ use crate::semantic_index_settings::SemanticIndexSettings;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase; use db::VectorDatabase;
use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use embedding_queue::{EmbeddingQueue, FileToEmbed}; use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt}; use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};

View file

@ -1,10 +1,10 @@
use crate::{ use crate::{
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue, embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest}, parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
semantic_index_settings::SemanticIndexSettings, semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
}; };
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use gpui::{executor::Deterministic, Task, TestAppContext}; use gpui::{executor::Deterministic, Task, TestAppContext};