Reify Embedding/Sha1 structs that can be (de)serialized from SQL

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
Antonio Scandurra 2023-08-31 17:55:43 +02:00
parent c763e728d1
commit 3001a46f69
5 changed files with 180 additions and 138 deletions

View file

@ -1,8 +1,7 @@
use crate::{
db::dot,
embedding::{DummyEmbeddings, EmbeddingProvider},
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Document},
parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex,
};
@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
documents: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
let content = RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect::<String>();
let sha1 = Sha1::from(content.as_str());
Document {
range: 0..10,
embedding: Vec::new(),
embedding: None,
name: format!("document {document_ix}"),
content: RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect(),
sha1: rng.gen(),
content,
sha1,
token_count: rng.gen_range(10..30),
}
})
@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
.map(|file| {
let mut file = file.clone();
for doc in &mut file.documents {
doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
})
@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
assert_eq!(
round_to_decimals(dot(&a, &b), 1),
round_to_decimals(reference_dot(&a, &b), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Vec<f32> {
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
*x /= norm;
}
result
result.into()
}
}
@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
200
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())