collab: Remove code for embeddings (#29310)
This PR removes the embeddings-related code from collab and the protocol, as we weren't using it anywhere.

Release Notes:
- N/A
This commit is contained in:
parent
218496744c
commit
ecc600a68f
8 changed files with 1 additions and 267 deletions
|
@ -1,9 +1,7 @@
|
|||
mod cloud;
|
||||
mod lmstudio;
|
||||
mod ollama;
|
||||
mod open_ai;
|
||||
|
||||
pub use cloud::*;
|
||||
pub use lmstudio::*;
|
||||
pub use ollama::*;
|
||||
pub use open_ai::*;
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::{Client, proto};
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, future::BoxFuture};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct CloudEmbeddingProvider {
|
||||
model: String,
|
||||
client: Arc<Client>,
|
||||
}
|
||||
|
||||
impl CloudEmbeddingProvider {
|
||||
pub fn new(client: Arc<Client>) -> Self {
|
||||
Self {
|
||||
model: "openai/text-embedding-3-small".into(),
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider for CloudEmbeddingProvider {
|
||||
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
|
||||
// First, fetch any embeddings that are cached based on the requested texts' digests
|
||||
// Then compute any embeddings that are missing.
|
||||
async move {
|
||||
if !self.client.status().borrow().is_connected() {
|
||||
return Err(anyhow!("sign in required"));
|
||||
}
|
||||
|
||||
let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
|
||||
model: self.model.clone(),
|
||||
digests: texts
|
||||
.iter()
|
||||
.map(|to_embed| to_embed.digest.to_vec())
|
||||
.collect(),
|
||||
});
|
||||
let mut embeddings = cached_embeddings
|
||||
.await
|
||||
.context("failed to fetch cached embeddings via cloud model")?
|
||||
.embeddings
|
||||
.into_iter()
|
||||
.map(|embedding| {
|
||||
let digest: [u8; 32] = embedding
|
||||
.digest
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
|
||||
Ok((digest, embedding.dimensions))
|
||||
})
|
||||
.collect::<Result<HashMap<_, _>>>()?;
|
||||
|
||||
let compute_embeddings_request = proto::ComputeEmbeddings {
|
||||
model: self.model.clone(),
|
||||
texts: texts
|
||||
.iter()
|
||||
.filter_map(|to_embed| {
|
||||
if embeddings.contains_key(&to_embed.digest) {
|
||||
None
|
||||
} else {
|
||||
Some(to_embed.text.to_string())
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
if !compute_embeddings_request.texts.is_empty() {
|
||||
let missing_embeddings = self.client.request(compute_embeddings_request).await?;
|
||||
for embedding in missing_embeddings.embeddings {
|
||||
let digest: [u8; 32] = embedding
|
||||
.digest
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
|
||||
embeddings.insert(digest, embedding.dimensions);
|
||||
}
|
||||
}
|
||||
|
||||
texts
|
||||
.iter()
|
||||
.map(|to_embed| {
|
||||
let embedding =
|
||||
embeddings.get(&to_embed.digest).cloned().with_context(|| {
|
||||
format!("server did not return an embedding for {:?}", to_embed)
|
||||
})?;
|
||||
Ok(Embedding::new(embedding))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn batch_size(&self) -> usize {
|
||||
2048
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue