fill embeddings with database values and skip during embeddings queue
This commit is contained in:
parent
220533ff1a
commit
50cfb067e7
2 changed files with 48 additions and 21 deletions
|
@ -42,6 +42,7 @@ pub struct EmbeddingQueue {
|
|||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileToEmbedFragment {
|
||||
file: Arc<Mutex<FileToEmbed>>,
|
||||
document_range: Range<usize>,
|
||||
|
@ -74,8 +75,16 @@ impl EmbeddingQueue {
|
|||
});
|
||||
|
||||
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
|
||||
let mut saved_tokens = 0;
|
||||
for (ix, document) in file.lock().documents.iter().enumerate() {
|
||||
let next_token_count = self.pending_batch_token_count + document.token_count;
|
||||
let document_token_count = if document.embedding.is_none() {
|
||||
document.token_count
|
||||
} else {
|
||||
saved_tokens += document.token_count;
|
||||
0
|
||||
};
|
||||
|
||||
let next_token_count = self.pending_batch_token_count + document_token_count;
|
||||
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
|
||||
let range_end = fragment_range.end;
|
||||
self.flush();
|
||||
|
@ -87,8 +96,9 @@ impl EmbeddingQueue {
|
|||
}
|
||||
|
||||
fragment_range.end = ix + 1;
|
||||
self.pending_batch_token_count += document.token_count;
|
||||
self.pending_batch_token_count += document_token_count;
|
||||
}
|
||||
log::trace!("Saved Tokens: {:?}", saved_tokens);
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
|
@ -100,25 +110,41 @@ impl EmbeddingQueue {
|
|||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
self.executor.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
let mut document_count = 0;
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
document_count += file.documents[fragment.document_range.clone()].len();
|
||||
spans.extend(
|
||||
{
|
||||
file.documents[fragment.document_range.clone()]
|
||||
.iter()
|
||||
.iter().filter(|d| d.embedding.is_none())
|
||||
.map(|d| d.content.clone())
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
log::trace!("Documents Length: {:?}", document_count);
|
||||
log::trace!("Span Length: {:?}", spans.clone().len());
|
||||
|
||||
// If spans is 0, just send the fragment to the finished files if its the last one.
|
||||
if spans.len() == 0 {
|
||||
for fragment in batch.clone() {
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
for document in
|
||||
&mut fragment.file.lock().documents[fragment.document_range.clone()]
|
||||
&mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
|
||||
{
|
||||
if let Some(embedding) = embeddings.next() {
|
||||
document.embedding = Some(embedding);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue