This commit is contained in:
Antonio Scandurra 2023-09-07 15:25:23 +02:00
parent 757a285852
commit a45c8c380f

View file

@ -108,54 +108,55 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone(); let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
self.executor.spawn(async move { self.executor
let mut spans = Vec::new(); .spawn(async move {
for fragment in &batch { let mut spans = Vec::new();
let file = fragment.file.lock(); for fragment in &batch {
spans.extend( let file = fragment.file.lock();
{ spans.extend(
file.spans[fragment.span_range.clone()] file.spans[fragment.span_range.clone()]
.iter().filter(|d| d.embedding.is_none()) .iter()
.map(|d| d.content.clone()) .filter(|d| d.embedding.is_none())
} .map(|d| d.content.clone()),
); );
}
// If spans is 0, just send the fragment to the finished files if its the last one.
if spans.len() == 0 {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
} }
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for span in
&mut fragment.file.lock().spans[fragment.span_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
span.embedding = Some(embedding);
} else {
log::error!("number of embeddings returned different from number of documents");
}
}
// If spans is 0, just send the fragment to the finished files if its the last one.
if spans.is_empty() {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) { if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap(); finished_files_tx.try_send(file.into_inner()).unwrap();
} }
} }
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
.iter_mut()
.filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
span.embedding = Some(embedding);
} else {
log::error!("number of embeddings != number of documents");
}
}
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
}
Err(error) => {
log::error!("{:?}", error);
}
} }
Err(error) => { })
log::error!("{:?}", error); .detach();
}
}
})
.detach();
} }
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> { pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {