fill embeddings with database values and skip during embeddings queue

This commit is contained in:
KCaverly 2023-08-31 13:19:17 -04:00
parent 220533ff1a
commit 50cfb067e7
2 changed files with 48 additions and 21 deletions

View file

@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileToEmbedFragment {
file: Arc<Mutex<FileToEmbed>>,
document_range: Range<usize>,
@ -74,8 +75,16 @@ impl EmbeddingQueue {
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
let mut saved_tokens = 0;
for (ix, document) in file.lock().documents.iter().enumerate() {
let next_token_count = self.pending_batch_token_count + document.token_count;
let document_token_count = if document.embedding.is_none() {
document.token_count
} else {
saved_tokens += document.token_count;
0
};
let next_token_count = self.pending_batch_token_count + document_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
@ -87,8 +96,9 @@ impl EmbeddingQueue {
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += document.token_count;
self.pending_batch_token_count += document_token_count;
}
log::trace!("Saved Tokens: {:?}", saved_tokens);
}
pub fn flush(&mut self) {
@ -100,25 +110,41 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor.spawn(async move {
let mut spans = Vec::new();
let mut document_count = 0;
for fragment in &batch {
let file = fragment.file.lock();
document_count += file.documents[fragment.document_range.clone()].len();
spans.extend(
{
file.documents[fragment.document_range.clone()]
.iter()
.iter().filter(|d| d.embedding.is_none())
.map(|d| d.content.clone())
}
);
}
log::trace!("Documents Length: {:?}", document_count);
log::trace!("Span Length: {:?}", spans.clone().len());
// If there are no spans to embed, just send each fragment's file to the finished files if it's the last one.
if spans.len() == 0 {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for document in
&mut fragment.file.lock().documents[fragment.document_range.clone()]
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
document.embedding = Some(embedding);