collab: Remove code for embeddings (#29310)

This PR removes the embeddings-related code from collab and the
protocol, as we weren't using it anywhere.
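
For context, the protocol-side cleanup amounts to retiring the four embedding messages and reserving their `Envelope` field numbers. A minimal, hypothetical sketch (the real proto file is only excerpted in the diff below; the names and numbers are taken from it):

```proto
// Hypothetical excerpt mirroring the Envelope change in the diff below.
message Envelope {
  // ... other variants unchanged ...

  // GetCachedEmbeddings (189), GetCachedEmbeddingsResponse (190),
  // ComputeEmbeddings (191), and ComputeEmbeddingsResponse (192) are
  // deleted; reserving their field numbers keeps future messages from
  // reusing them and confusing older clients.
  reserved 189 to 192;
}
```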

Release Notes:

- N/A
Marshall Bowers 2025-04-23 18:27:46 -04:00 committed by GitHub
parent 218496744c
commit ecc600a68f
8 changed files with 1 addition and 267 deletions

Cargo.lock (generated)

@@ -3017,7 +3017,6 @@ dependencies = [
"nanoid",
"node_runtime",
"notifications",
"open_ai",
"parking_lot",
"pretty_assertions",
"project",

@@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.14"
prost.workspace = true

@@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use http_client::HttpClient;
use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
use sha2::Digest;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use futures::{
@@ -437,18 +435,6 @@ impl Server {
.await
}
}
})
.add_request_handler(get_cached_embeddings)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
compute_embeddings(
request,
response,
session,
app_state.config.openai_api_key.clone(),
)
}
});
Arc::new(server)
@@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
}
}
struct ZedProComputeEmbeddingsRateLimit;
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"zed-pro:compute-embeddings"
}
}
struct FreeComputeEmbeddingsRateLimit;
impl RateLimit for FreeComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:compute-embeddings"
}
}
async fn compute_embeddings(
request: proto::ComputeEmbeddings,
response: Response<proto::ComputeEmbeddings>,
session: Session,
api_key: Option<Arc<str>>,
) -> Result<()> {
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
let embeddings = match request.model.as_str() {
"openai/text-embedding-3-small" => {
open_ai::embed(
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
OpenAiEmbeddingModel::TextEmbedding3Small,
request.texts.iter().map(|text| text.as_str()),
)
.await?
}
provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
};
let embeddings = request
.texts
.iter()
.map(|text| {
let mut hasher = sha2::Sha256::new();
hasher.update(text.as_bytes());
let result = hasher.finalize();
result.to_vec()
})
.zip(
embeddings
.data
.into_iter()
.map(|embedding| embedding.embedding),
)
.collect::<HashMap<_, _>>();
let db = session.db().await;
db.save_embeddings(&request.model, &embeddings)
.await
.context("failed to save embeddings")
.trace_err();
response.send(proto::ComputeEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
async fn get_cached_embeddings(
request: proto::GetCachedEmbeddings,
response: Response<proto::GetCachedEmbeddings>,
session: Session,
) -> Result<()> {
authorize_access_to_legacy_llm_endpoints(&session).await?;
let db = session.db().await;
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
response.send(proto::GetCachedEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
/// This is leftover from before the LLM service.
///
/// The endpoints protected by this check will be moved there eventually.

@@ -188,26 +188,3 @@ enum LanguageModelProvider {
Google = 2;
Zed = 3;
}
message GetCachedEmbeddings {
string model = 1;
repeated bytes digests = 2;
}
message GetCachedEmbeddingsResponse {
repeated Embedding embeddings = 1;
}
message ComputeEmbeddings {
string model = 1;
repeated string texts = 2;
}
message ComputeEmbeddingsResponse {
repeated Embedding embeddings = 1;
}
message Embedding {
bytes digest = 1;
repeated float dimensions = 2;
}

@@ -208,10 +208,6 @@ message Envelope {
CountLanguageModelTokens count_language_model_tokens = 230;
CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
ComputeEmbeddingsResponse compute_embeddings_response = 192;
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@@ -394,6 +390,7 @@ message Envelope {
reserved 166 to 169;
reserved 177 to 185;
reserved 188;
reserved 189 to 192;
reserved 193 to 195;
reserved 197;
reserved 200 to 202;

@@ -49,8 +49,6 @@ messages!(
(ChannelMessageUpdate, Foreground),
(CloseBuffer, Foreground),
(Commit, Background),
(ComputeEmbeddings, Background),
(ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
@@ -82,8 +80,6 @@ messages!(
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
(GetCachedEmbeddings, Background),
(GetCachedEmbeddingsResponse, Background),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
@@ -319,7 +315,6 @@ request_messages!(
(CancelCall, Ack),
(Commit, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(ComputeEmbeddings, ComputeEmbeddingsResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
@@ -332,7 +327,6 @@ request_messages!(
(ApplyCodeActionKind, ApplyCodeActionKindResponse),
(FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
(GetCachedEmbeddings, GetCachedEmbeddingsResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),

@@ -1,9 +1,7 @@
mod cloud;
mod lmstudio;
mod ollama;
mod open_ai;
pub use cloud::*;
pub use lmstudio::*;
pub use ollama::*;
pub use open_ai::*;

@@ -1,93 +0,0 @@
use crate::{Embedding, EmbeddingProvider, TextToEmbed};
use anyhow::{Context as _, Result, anyhow};
use client::{Client, proto};
use collections::HashMap;
use futures::{FutureExt, future::BoxFuture};
use std::sync::Arc;
pub struct CloudEmbeddingProvider {
model: String,
client: Arc<Client>,
}
impl CloudEmbeddingProvider {
pub fn new(client: Arc<Client>) -> Self {
Self {
model: "openai/text-embedding-3-small".into(),
client,
}
}
}
impl EmbeddingProvider for CloudEmbeddingProvider {
fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
// First, fetch any embeddings that are cached based on the requested texts' digests
// Then compute any embeddings that are missing.
async move {
if !self.client.status().borrow().is_connected() {
return Err(anyhow!("sign in required"));
}
let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
model: self.model.clone(),
digests: texts
.iter()
.map(|to_embed| to_embed.digest.to_vec())
.collect(),
});
let mut embeddings = cached_embeddings
.await
.context("failed to fetch cached embeddings via cloud model")?
.embeddings
.into_iter()
.map(|embedding| {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
Ok((digest, embedding.dimensions))
})
.collect::<Result<HashMap<_, _>>>()?;
let compute_embeddings_request = proto::ComputeEmbeddings {
model: self.model.clone(),
texts: texts
.iter()
.filter_map(|to_embed| {
if embeddings.contains_key(&to_embed.digest) {
None
} else {
Some(to_embed.text.to_string())
}
})
.collect(),
};
if !compute_embeddings_request.texts.is_empty() {
let missing_embeddings = self.client.request(compute_embeddings_request).await?;
for embedding in missing_embeddings.embeddings {
let digest: [u8; 32] = embedding
.digest
.try_into()
.map_err(|_| anyhow!("invalid digest for cached embedding"))?;
embeddings.insert(digest, embedding.dimensions);
}
}
texts
.iter()
.map(|to_embed| {
let embedding =
embeddings.get(&to_embed.digest).cloned().with_context(|| {
format!("server did not return an embedding for {:?}", to_embed)
})?;
Ok(Embedding::new(embedding))
})
.collect()
}
.boxed()
}
fn batch_size(&self) -> usize {
2048
}
}