Move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch
Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
9781047156
commit
76ce52df4e
5 changed files with 295 additions and 91 deletions
|
@ -1,14 +1,16 @@
|
|||
use crate::{
|
||||
db::dot,
|
||||
embedding::{DummyEmbeddings, EmbeddingProvider},
|
||||
embedding_queue::EmbeddingQueue,
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Document},
|
||||
semantic_index_settings::SemanticIndexSettings,
|
||||
SearchResult, SemanticIndex,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
|
@ -20,8 +22,10 @@ use std::{
|
|||
atomic::{self, AtomicUsize},
|
||||
Arc,
|
||||
},
|
||||
time::SystemTime,
|
||||
};
|
||||
use unindent::Unindent;
|
||||
use util::RandomCharIter;
|
||||
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
|
@ -32,11 +36,7 @@ fn init_logger() {
|
|||
|
||||
#[gpui::test]
|
||||
async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
settings::register::<ProjectSettings>(cx);
|
||||
});
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
fs.insert_tree(
|
||||
|
@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
let db_path = db_dir.path().join("db.sqlite");
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let store = SemanticIndex::new(
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs.clone(),
|
||||
db_path,
|
||||
embedding_provider.clone(),
|
||||
|
@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
|
||||
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
|
||||
|
||||
let _ = store
|
||||
let _ = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.initialize_project(project.clone(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
let (file_count, outstanding_file_count) = store
|
||||
let (file_count, outstanding_file_count) = semantic_index
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
cx.foreground().run_until_parked();
|
||||
assert_eq!(*outstanding_file_count.borrow(), 0);
|
||||
|
||||
let search_results = store
|
||||
let search_results = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
|
@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
// Test Include Files Functionality
|
||||
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
|
||||
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
|
||||
let rust_only_search_results = store
|
||||
let rust_only_search_results = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
|
@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
cx,
|
||||
);
|
||||
|
||||
let no_rust_search_results = store
|
||||
let no_rust_search_results = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
|
@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
cx.foreground().run_until_parked();
|
||||
|
||||
let prev_embedding_count = embedding_provider.embedding_count();
|
||||
let (file_count, outstanding_file_count) = store
|
||||
let (file_count, outstanding_file_count) = semantic_index
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
let (outstanding_job_count, _) = postage::watch::channel_with(0);
|
||||
let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
|
||||
|
||||
let files = (1..=3)
|
||||
.map(|file_ix| FileToEmbed {
|
||||
worktree_id: 5,
|
||||
path: format!("path-{file_ix}").into(),
|
||||
mtime: SystemTime::now(),
|
||||
documents: (0..rng.gen_range(4..22))
|
||||
.map(|document_ix| {
|
||||
let content_len = rng.gen_range(10..100);
|
||||
Document {
|
||||
range: 0..10,
|
||||
embedding: Vec::new(),
|
||||
name: format!("document {document_ix}"),
|
||||
content: RandomCharIter::new(&mut rng)
|
||||
.with_simple_text()
|
||||
.take(content_len)
|
||||
.collect(),
|
||||
sha1: rng.gen(),
|
||||
token_count: rng.gen_range(10..30),
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
job_handle: JobHandle::new(&outstanding_job_count),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone());
|
||||
|
||||
let finished_files = cx.update(|cx| {
|
||||
for file in &files {
|
||||
queue.push(file.clone(), cx);
|
||||
}
|
||||
queue.flush(cx);
|
||||
queue.finished_files()
|
||||
});
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
let mut embedded_files: Vec<_> = files
|
||||
.iter()
|
||||
.map(|_| finished_files.try_recv().expect("no finished file"))
|
||||
.collect();
|
||||
|
||||
let expected_files: Vec<_> = files
|
||||
.iter()
|
||||
.map(|file| {
|
||||
let mut file = file.clone();
|
||||
for doc in &mut file.documents {
|
||||
doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
|
||||
}
|
||||
file
|
||||
})
|
||||
.collect();
|
||||
|
||||
embedded_files.sort_by_key(|f| f.path.clone());
|
||||
|
||||
assert_eq!(embedded_files, expected_files);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_search_results(
|
||||
actual: &[SearchResult],
|
||||
|
@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider {
|
|||
fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn embed_sync(&self, span: &str) -> Vec<f32> {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn count_tokens(&self, span: &str) -> usize {
|
||||
span.len()
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
(span.to_string(), 1)
|
||||
}
|
||||
|
||||
fn should_truncate(&self, span: &str) -> bool {
|
||||
false
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
200
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> String {
|
||||
span.to_string()
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans
|
||||
.iter()
|
||||
.map(|span| {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result
|
||||
})
|
||||
.collect())
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1704,3 +1762,11 @@ fn test_subtract_ranges() {
|
|||
|
||||
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
settings::register::<ProjectSettings>(cx);
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue