Reify Embedding/Sha1 structs that can be (de)serialized from SQL
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
This commit is contained in:
parent
c763e728d1
commit
3001a46f69
5 changed files with 180 additions and 138 deletions
|
@ -1,8 +1,7 @@
|
|||
use crate::{
|
||||
db::dot,
|
||||
embedding::{DummyEmbeddings, EmbeddingProvider},
|
||||
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
|
||||
embedding_queue::EmbeddingQueue,
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Document},
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
|
||||
semantic_index_settings::SemanticIndexSettings,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex,
|
||||
};
|
||||
|
@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||
documents: (0..rng.gen_range(4..22))
|
||||
.map(|document_ix| {
|
||||
let content_len = rng.gen_range(10..100);
|
||||
let content = RandomCharIter::new(&mut rng)
|
||||
.with_simple_text()
|
||||
.take(content_len)
|
||||
.collect::<String>();
|
||||
let sha1 = Sha1::from(content.as_str());
|
||||
Document {
|
||||
range: 0..10,
|
||||
embedding: Vec::new(),
|
||||
embedding: None,
|
||||
name: format!("document {document_ix}"),
|
||||
content: RandomCharIter::new(&mut rng)
|
||||
.with_simple_text()
|
||||
.take(content_len)
|
||||
.collect(),
|
||||
sha1: rng.gen(),
|
||||
content,
|
||||
sha1,
|
||||
token_count: rng.gen_range(10..30),
|
||||
}
|
||||
})
|
||||
|
@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||
.map(|file| {
|
||||
let mut file = file.clone();
|
||||
for doc in &mut file.documents {
|
||||
doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
|
||||
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
|
||||
}
|
||||
file
|
||||
})
|
||||
|
@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_dot_product(mut rng: StdRng) {
|
||||
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
|
||||
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(dot(&a, &b), 1),
|
||||
round_to_decimals(reference_dot(&a, &b), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize,
|
||||
|
@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
|
|||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn embed_sync(&self, span: &str) -> Vec<f32> {
|
||||
fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
|
@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
|
|||
*x /= norm;
|
||||
}
|
||||
|
||||
result
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
|
|||
200
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue