use crate::{Embedding, EmbeddingProvider, TextToEmbed}; use anyhow::{anyhow, Context as _, Result}; use client::{proto, Client}; use collections::HashMap; use futures::{future::BoxFuture, FutureExt}; use std::sync::Arc; pub struct CloudEmbeddingProvider { model: String, client: Arc, } impl CloudEmbeddingProvider { pub fn new(client: Arc) -> Self { Self { model: "openai/text-embedding-3-small".into(), client, } } } impl EmbeddingProvider for CloudEmbeddingProvider { fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result>> { // First, fetch any embeddings that are cached based on the requested texts' digests // Then compute any embeddings that are missing. async move { if !self.client.status().borrow().is_connected() { return Err(anyhow!("sign in required")); } let cached_embeddings = self.client.request(proto::GetCachedEmbeddings { model: self.model.clone(), digests: texts .iter() .map(|to_embed| to_embed.digest.to_vec()) .collect(), }); let mut embeddings = cached_embeddings .await .context("failed to fetch cached embeddings via cloud model")? .embeddings .into_iter() .map(|embedding| { let digest: [u8; 32] = embedding .digest .try_into() .map_err(|_| anyhow!("invalid digest for cached embedding"))?; Ok((digest, embedding.dimensions)) }) .collect::>>()?; let compute_embeddings_request = proto::ComputeEmbeddings { model: self.model.clone(), texts: texts .iter() .filter_map(|to_embed| { if embeddings.contains_key(&to_embed.digest) { None } else { Some(to_embed.text.to_string()) } }) .collect(), }; if !compute_embeddings_request.texts.is_empty() { let missing_embeddings = self.client.request(compute_embeddings_request).await?; for embedding in missing_embeddings.embeddings { let digest: [u8; 32] = embedding .digest .try_into() .map_err(|_| anyhow!("invalid digest for cached embedding"))?; embeddings.insert(digest, embedding.dimensions); } } texts .iter() .map(|to_embed| { let embedding = embeddings.get(&to_embed.digest).cloned().with_context(|| { format!("server did not return an embedding for {:?}", to_embed) })?; Ok(Embedding::new(embedding)) }) .collect() } .boxed() } fn batch_size(&self) -> usize { 2048 } }