commit 76ce52df4e
parent 9781047156
Co-authored-by: Max <max@zed.dev>

    move queuing to embedding_queue functionality and update embedding provider
    to include trait items for max tokens per batch

5 changed files with 295 additions and 91 deletions
crates/semantic_index/src/embedding.rs

@@ -53,36 +53,30 @@ struct OpenAIEmbeddingUsage {
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
-    fn count_tokens(&self, span: &str) -> usize;
-    fn should_truncate(&self, span: &str) -> bool;
-    fn truncate(&self, span: &str) -> String;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+    fn max_tokens_per_batch(&self) -> usize;
+    fn truncate(&self, span: &str) -> (String, usize);
 }

 pub struct DummyEmbeddings {}

 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
         let dummy_vec = vec![0.32 as f32; 1536];
         return Ok(vec![dummy_vec; spans.len()]);
     }

-    fn count_tokens(&self, span: &str) -> usize {
-        // For Dummy Providers, we are going to use OpenAI tokenization for ease
-        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        tokens.len()
-    }
-
-    fn should_truncate(&self, span: &str) -> bool {
-        self.count_tokens(span) > OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> String {
+    fn max_tokens_per_batch(&self) -> usize {
+        OPENAI_INPUT_LIMIT
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
             tokens.truncate(OPENAI_INPUT_LIMIT);
             OPENAI_BPE_TOKENIZER
                 .decode(tokens)
@@ -92,7 +86,7 @@ impl EmbeddingProvider for DummyEmbeddings {
             span.to_string()
         };

-        output
+        (output, token_count)
     }
 }

@@ -125,19 +119,14 @@ impl OpenAIEmbeddings {

 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    fn count_tokens(&self, span: &str) -> usize {
-        // For Dummy Providers, we are going to use OpenAI tokenization for ease
-        let tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        tokens.len()
-    }
-
-    fn should_truncate(&self, span: &str) -> bool {
-        self.count_tokens(span) > OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> String {
+    fn max_tokens_per_batch(&self) -> usize {
+        OPENAI_INPUT_LIMIT
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
             tokens.truncate(OPENAI_INPUT_LIMIT);
             OPENAI_BPE_TOKENIZER
                 .decode(tokens)
@@ -147,10 +136,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             span.to_string()
         };

-        output
+        (output, token_count)
     }

-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
@@ -160,9 +149,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         let mut request_number = 0;
         let mut request_timeout: u64 = 10;
-        let mut truncated = false;
         let mut response: Response<AsyncBody>;
-        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
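Net effect of this hunk on the trait: truncate now reports the token count in the same pass (replacing the separate count_tokens and should_truncate calls), embed_batch takes owned Strings so batches can be moved into background tasks, and max_tokens_per_batch exposes the provider's batch budget. A minimal call-site sketch under the new API; `provider` and `raw_span` are hypothetical stand-ins:

    // `provider: Arc<dyn EmbeddingProvider>` (hypothetical), inside an async fn.
    let (span, token_count) = provider.truncate(&raw_span); // truncate + count in one pass
    let embeddings = provider.embed_batch(vec![span]).await?; // owned Strings now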
crates/semantic_index/src/embedding_queue.rs (new file, 140 lines)
@@ -0,0 +1,140 @@
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+
+use gpui::AppContext;
+use parking_lot::Mutex;
+use smol::channel;
+
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+
+#[derive(Clone)]
+pub struct FileToEmbed {
+    pub worktree_id: i64,
+    pub path: PathBuf,
+    pub mtime: SystemTime,
+    pub documents: Vec<Document>,
+    pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FileToEmbed")
+            .field("worktree_id", &self.worktree_id)
+            .field("path", &self.path)
+            .field("mtime", &self.mtime)
+            .field("document", &self.documents)
+            .finish_non_exhaustive()
+    }
+}
+
+impl PartialEq for FileToEmbed {
+    fn eq(&self, other: &Self) -> bool {
+        self.worktree_id == other.worktree_id
+            && self.path == other.path
+            && self.mtime == other.mtime
+            && self.documents == other.documents
+    }
+}
+
+pub struct EmbeddingQueue {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    pending_batch: Vec<FileToEmbedFragment>,
+    pending_batch_token_count: usize,
+    finished_files_tx: channel::Sender<FileToEmbed>,
+    finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+pub struct FileToEmbedFragment {
+    file: Arc<Mutex<FileToEmbed>>,
+    document_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+        let (finished_files_tx, finished_files_rx) = channel::unbounded();
+        Self {
+            embedding_provider,
+            pending_batch: Vec::new(),
+            pending_batch_token_count: 0,
+            finished_files_tx,
+            finished_files_rx,
+        }
+    }
+
+    pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
+        let file = Arc::new(Mutex::new(file));
+
+        self.pending_batch.push(FileToEmbedFragment {
+            file: file.clone(),
+            document_range: 0..0,
+        });
+
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+        for (ix, document) in file.lock().documents.iter().enumerate() {
+            let next_token_count = self.pending_batch_token_count + document.token_count;
+            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
+                let range_end = fragment_range.end;
+                self.flush(cx);
+                self.pending_batch.push(FileToEmbedFragment {
+                    file: file.clone(),
+                    document_range: range_end..range_end,
+                });
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+            }
+
+            fragment_range.end = ix + 1;
+            self.pending_batch_token_count += document.token_count;
+        }
+    }
+
+    pub fn flush(&mut self, cx: &mut AppContext) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        cx.background().spawn(async move {
+            let mut spans = Vec::new();
+            for fragment in &batch {
+                let file = fragment.file.lock();
+                spans.extend(
+                    file.documents[fragment.document_range.clone()]
+                        .iter()
+                        .map(|d| d.content.clone()),
+                );
+            }
+
+            match embedding_provider.embed_batch(spans).await {
+                Ok(embeddings) => {
+                    let mut embeddings = embeddings.into_iter();
+                    for fragment in batch {
+                        for document in
+                            &mut fragment.file.lock().documents[fragment.document_range.clone()]
+                        {
+                            if let Some(embedding) = embeddings.next() {
+                                document.embedding = embedding;
+                            } else {
+                                //
+                                log::error!("number of embeddings returned different from number of documents");
+                            }
+                        }
+
+                        if let Some(file) = Arc::into_inner(fragment.file) {
+                            finished_files_tx.try_send(file.into_inner()).unwrap();
+                        }
+                    }
+                }
+                Err(error) => {
+                    log::error!("{:?}", error);
+                }
+            }
+        })
+        .detach();
+    }
+
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}
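EmbeddingQueue packs documents from successive files into provider-sized batches: push splits a file into fragments whenever the next document would push the pending batch past max_tokens_per_batch, and flush hands the accumulated spans to a background task that writes embeddings back through the shared Arc<Mutex<FileToEmbed>> and emits each fully embedded file on the finished channel. A rough usage sketch, assuming a provider, a FileToEmbed, and an AppContext like those in the tests below:

    // Hedged sketch; `provider`, `file`, and `cx` come from the surrounding app.
    let mut queue = EmbeddingQueue::new(provider);
    let finished = queue.finished_files(); // grab a receiver up front
    queue.push(file, cx);                  // may flush internally on overflow
    queue.flush(cx);                       // force out the final partial batch
    // Fully embedded FileToEmbed values arrive on `finished` as batches complete.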
crates/semantic_index/src/parsing.rs

@@ -72,8 +72,7 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);

-        let token_count = self.embedding_provider.count_tokens(&document_span);
-        let document_span = self.embedding_provider.truncate(&document_span);
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);

         Ok(vec![Document {
             range: 0..content.len(),
@@ -93,8 +92,7 @@ impl CodeContextRetriever {
         let mut sha1 = Sha1::new();
         sha1.update(&document_span);

-        let token_count = self.embedding_provider.count_tokens(&document_span);
-        let document_span = self.embedding_provider.truncate(&document_span);
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);

         Ok(vec![Document {
             range: 0..content.len(),
@@ -183,8 +181,8 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("item", &document.content);

-        let token_count = self.embedding_provider.count_tokens(&document_content);
-        let document_content = self.embedding_provider.truncate(&document_content);
+        let (document_content, token_count) =
+            self.embedding_provider.truncate(&document_content);

         document.content = document_content;
         document.token_count = token_count;
crates/semantic_index/src/semantic_index_tests.rs

@@ -1,14 +1,16 @@
 use crate::{
     db::dot,
     embedding::{DummyEmbeddings, EmbeddingProvider},
+    embedding_queue::EmbeddingQueue,
     parsing::{subtract_ranges, CodeContextRetriever, Document},
     semantic_index_settings::SemanticIndexSettings,
-    SearchResult, SemanticIndex,
+    FileToEmbed, JobHandle, SearchResult, SemanticIndex,
 };
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
 use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
@@ -20,8 +22,10 @@ use std::{
         atomic::{self, AtomicUsize},
         Arc,
     },
+    time::SystemTime,
 };
 use unindent::Unindent;
+use util::RandomCharIter;

 #[ctor::ctor]
 fn init_logger() {
@@ -32,11 +36,7 @@ fn init_logger() {
 #[gpui::test]
 async fn test_semantic_index(cx: &mut TestAppContext) {
-    cx.update(|cx| {
-        cx.set_global(SettingsStore::test(cx));
-        settings::register::<SemanticIndexSettings>(cx);
-        settings::register::<ProjectSettings>(cx);
-    });
+    init_test(cx);

     let fs = FakeFs::new(cx.background());
     fs.insert_tree(
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let db_path = db_dir.path().join("db.sqlite");

     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let store = SemanticIndex::new(
+    let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_path,
         embedding_provider.clone(),
@@ -87,13 +87,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;

-    let _ = store
+    let _ = semantic_index
         .update(cx, |store, cx| {
             store.initialize_project(project.clone(), cx)
         })
         .await;

-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
@@ -101,7 +101,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();
     assert_eq!(*outstanding_file_count.borrow(), 0);

-    let search_results = store
+    let search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -129,7 +129,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     // Test Include Files Functonality
     let include_files = vec![PathMatcher::new("*.rs").unwrap()];
     let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
-    let rust_only_search_results = store
+    let rust_only_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -153,7 +153,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         cx,
     );

-    let no_rust_search_results = store
+    let no_rust_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -189,7 +189,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();

     let prev_embedding_count = embedding_provider.embedding_count();
-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
@@ -204,6 +204,69 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     );
 }

+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    let (outstanding_job_count, _) = postage::watch::channel_with(0);
+    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
+    let files = (1..=3)
+        .map(|file_ix| FileToEmbed {
+            worktree_id: 5,
+            path: format!("path-{file_ix}").into(),
+            mtime: SystemTime::now(),
+            documents: (0..rng.gen_range(4..22))
+                .map(|document_ix| {
+                    let content_len = rng.gen_range(10..100);
+                    Document {
+                        range: 0..10,
+                        embedding: Vec::new(),
+                        name: format!("document {document_ix}"),
+                        content: RandomCharIter::new(&mut rng)
+                            .with_simple_text()
+                            .take(content_len)
+                            .collect(),
+                        sha1: rng.gen(),
+                        token_count: rng.gen_range(10..30),
+                    }
+                })
+                .collect(),
+            job_handle: JobHandle::new(&outstanding_job_count),
+        })
+        .collect::<Vec<_>>();
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone());
+
+    let finished_files = cx.update(|cx| {
+        for file in &files {
+            queue.push(file.clone(), cx);
+        }
+        queue.flush(cx);
+        queue.finished_files()
+    });
+
+    cx.foreground().run_until_parked();
+    let mut embedded_files: Vec<_> = files
+        .iter()
+        .map(|_| finished_files.try_recv().expect("no finished file"))
+        .collect();
+
+    let expected_files: Vec<_> = files
+        .iter()
+        .map(|file| {
+            let mut file = file.clone();
+            for doc in &mut file.documents {
+                doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
+            }
+            file
+        })
+        .collect();
+
+    embedded_files.sort_by_key(|f| f.path.clone());
+
+    assert_eq!(embedded_files, expected_files);
+}
+
 #[track_caller]
 fn assert_search_results(
     actual: &[SearchResult],
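Why ten iterations: each file carries 4 to 21 documents of 10 to 29 tokens (up to roughly 600 tokens against FakeEmbeddingProvider's 200-token budget, defined below), so different seeds exercise both the single-batch path and the fragment-splitting path in EmbeddingQueue::push. A worked trace of the split, with hypothetical round numbers:

    // Assume budget = 200 and every document = 25 tokens.
    // ix 0..=7: pending grows to exactly 200; the check is `>`, so all 8 fit (range 0..8).
    // ix 8: 200 + 25 = 225 > 200 -> flush(), a new fragment starts at 8..8,
    //       then extends to 8..9 with pending = 25.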
@@ -1220,47 +1283,42 @@ impl FakeEmbeddingProvider {
     fn embedding_count(&self) -> usize {
         self.embedding_count.load(atomic::Ordering::SeqCst)
     }
+
+    fn embed_sync(&self, span: &str) -> Vec<f32> {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result
+    }
 }

 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn count_tokens(&self, span: &str) -> usize {
-        span.len()
-    }
-
-    fn should_truncate(&self, span: &str) -> bool {
-        false
-    }
-
-    fn truncate(&self, span: &str) -> String {
-        span.to_string()
-    }
-
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    fn truncate(&self, span: &str) -> (String, usize) {
+        (span.to_string(), 1)
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        200
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans
-            .iter()
-            .map(|span| {
-                let mut result = vec![1.0; 26];
-                for letter in span.chars() {
-                    let letter = letter.to_ascii_lowercase();
-                    if letter as u32 >= 'a' as u32 {
-                        let ix = (letter as u32) - ('a' as u32);
-                        if ix < 26 {
-                            result[ix as usize] += 1.0;
-                        }
-                    }
-                }
-
-                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-                for x in &mut result {
-                    *x /= norm;
-                }
-
-                result
-            })
-            .collect())
+        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
     }
 }
@@ -1704,3 +1762,11 @@ fn test_subtract_ranges() {
     assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
 }

+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|cx| {
+        cx.set_global(SettingsStore::test(cx));
+        settings::register::<SemanticIndexSettings>(cx);
+        settings::register::<ProjectSettings>(cx);
+    });
+}
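embed_sync is a deterministic, L2-normalized letter-frequency embedding, which is what lets test_embedding_batching compute expected_files synchronously and compare them against the queue's asynchronous output. Worked by hand for a tiny input:

    // For span "aab": counts start at vec![1.0; 26]; two 'a's and one 'b' give
    // [3.0, 2.0, 1.0, ..., 1.0], and each entry is then divided by the norm
    // sqrt(9.0 + 4.0 + 24.0) = sqrt(37.0) ≈ 6.08.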
crates/util/src/util.rs

@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
     Defer(Some(f))
 }

-pub struct RandomCharIter<T: Rng>(T);
+pub struct RandomCharIter<T: Rng> {
+    rng: T,
+    simple_text: bool,
+}

 impl<T: Rng> RandomCharIter<T> {
     pub fn new(rng: T) -> Self {
-        Self(rng)
+        Self {
+            rng,
+            simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
+        }
+    }
+
+    pub fn with_simple_text(mut self) -> Self {
+        self.simple_text = true;
+        self
     }
 }
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
     type Item = char;

     fn next(&mut self) -> Option<Self::Item> {
-        if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
-            return if self.0.gen_range(0..100) < 5 {
+        if self.simple_text {
+            return if self.rng.gen_range(0..100) < 5 {
                 Some('\n')
             } else {
-                Some(self.0.gen_range(b'a'..b'z' + 1).into())
+                Some(self.rng.gen_range(b'a'..b'z' + 1).into())
             };
         }

-        match self.0.gen_range(0..100) {
+        match self.rng.gen_range(0..100) {
             // whitespace
-            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
+            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
             // two-byte greek letters
-            20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
+            20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
             // // three-byte characters
-            33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
+            33..=45 => ['✋', '✅', '❌', '❎', '⭐']
+                .choose(&mut self.rng)
+                .copied(),
             // // four-byte characters
-            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
+            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
             // ascii letters
-            _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
+            _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
         }
     }
 }
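Reading SIMPLE_TEXT once at construction and exposing with_simple_text lets tests opt in per iterator (as test_embedding_batching does above) instead of depending on the environment at every next() call. A small usage sketch with a hypothetical seeded rng:

    use rand::SeedableRng;

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    let content: String = RandomCharIter::new(&mut rng)
        .with_simple_text() // lowercase ascii plus occasional '\n' only
        .take(64)
        .collect();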