From ecc600a68f52b421092384ab2f5947984d8fa677 Mon Sep 17 00:00:00 2001
From: Marshall Bowers
Date: Wed, 23 Apr 2025 18:27:46 -0400
Subject: [PATCH] collab: Remove code for embeddings (#29310)

This PR removes the embeddings-related code from collab and the protocol,
as we weren't using it anywhere.

Release Notes:

- N/A
---
 Cargo.lock                                   |   1 -
 crates/collab/Cargo.toml                     |   1 -
 crates/collab/src/rpc.rs                     | 137 -------------------
 crates/proto/proto/ai.proto                  |  23 ----
 crates/proto/proto/zed.proto                 |   5 +-
 crates/proto/src/proto.rs                    |   6 -
 crates/semantic_index/src/embedding.rs       |   2 -
 crates/semantic_index/src/embedding/cloud.rs |  93 -------------
 8 files changed, 1 insertion(+), 267 deletions(-)
 delete mode 100644 crates/semantic_index/src/embedding/cloud.rs

diff --git a/Cargo.lock b/Cargo.lock
index 9f24b3d82c..1b4e88b2a6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3017,7 +3017,6 @@ dependencies = [
  "nanoid",
  "node_runtime",
  "notifications",
- "open_ai",
  "parking_lot",
  "pretty_assertions",
  "project",
diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml
index b446e264ac..d6df4be6df 100644
--- a/crates/collab/Cargo.toml
+++ b/crates/collab/Cargo.toml
@@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
 livekit_api.workspace = true
 log.workspace = true
 nanoid.workspace = true
-open_ai.workspace = true
 parking_lot.workspace = true
 prometheus = "0.14"
 prost.workspace = true
diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index 0e6c386a86..76f5f94641 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
 use http_client::HttpClient;
-use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
 use reqwest_client::ReqwestClient;
 use rpc::proto::split_repository_update;
-use sha2::Digest;
 use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
 
 use futures::{
@@ -437,18 +435,6 @@ impl Server {
                     .await
                 }
             }
-            })
-            .add_request_handler(get_cached_embeddings)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    compute_embeddings(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                    )
-                }
             });
         Arc::new(server)
     }
@@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
     }
 }
 
-struct ZedProComputeEmbeddingsRateLimit;
-
-impl RateLimit for ZedProComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:compute-embeddings"
-    }
-}
-
-struct FreeComputeEmbeddingsRateLimit;
-
-impl RateLimit for FreeComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:compute-embeddings"
-    }
-}
-
-async fn compute_embeddings(
-    request: proto::ComputeEmbeddings,
-    response: Response<proto::ComputeEmbeddings>,
-    session: Session,
-    api_key: Option<Arc<str>>,
-) -> Result<()> {
-    let api_key = api_key.context("no OpenAI API key configured on the server")?;
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
-        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
-        proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let embeddings = match request.model.as_str() {
-        "openai/text-embedding-3-small" => {
-            open_ai::embed(
-                session.http_client.as_ref(),
-                OPEN_AI_API_URL,
-                &api_key,
-                OpenAiEmbeddingModel::TextEmbedding3Small,
-                request.texts.iter().map(|text| text.as_str()),
-            )
-            .await?
-        }
-        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
-    };
-
-    let embeddings = request
-        .texts
-        .iter()
-        .map(|text| {
-            let mut hasher = sha2::Sha256::new();
-            hasher.update(text.as_bytes());
-            let result = hasher.finalize();
-            result.to_vec()
-        })
-        .zip(
-            embeddings
-                .data
-                .into_iter()
-                .map(|embedding| embedding.embedding),
-        )
-        .collect::<Vec<_>>();
-
-    let db = session.db().await;
-    db.save_embeddings(&request.model, &embeddings)
-        .await
-        .context("failed to save embeddings")
-        .trace_err();
-
-    response.send(proto::ComputeEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
-async fn get_cached_embeddings(
-    request: proto::GetCachedEmbeddings,
-    response: Response<proto::GetCachedEmbeddings>,
-    session: Session,
-) -> Result<()> {
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let db = session.db().await;
-    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
-
-    response.send(proto::GetCachedEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
 /// This is leftover from before the LLM service.
 ///
 /// The endpoints protected by this check will be moved there eventually.
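Note: as the rpc.rs hunk above shows, the removed compute_embeddings handler keyed cached embeddings by the SHA-256 digest of each input text. A minimal sketch of that digest computation using the sha2 crate, matching the hashing in the removed code (the helper name is hypothetical; the removed handler hashed inline):

    use sha2::{Digest, Sha256};

    // Cache key for an embedding: the SHA-256 digest of the text's raw
    // bytes, so identical texts map to the same cached row.
    fn text_digest(text: &str) -> Vec<u8> {
        let mut hasher = Sha256::new();
        hasher.update(text.as_bytes());
        hasher.finalize().to_vec()
    }
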
diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto
index 6b6b33df59..c39345c3c2 100644
--- a/crates/proto/proto/ai.proto
+++ b/crates/proto/proto/ai.proto
@@ -188,26 +188,3 @@ enum LanguageModelProvider {
     Google = 2;
     Zed = 3;
 }
-
-message GetCachedEmbeddings {
-    string model = 1;
-    repeated bytes digests = 2;
-}
-
-message GetCachedEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message ComputeEmbeddings {
-    string model = 1;
-    repeated string texts = 2;
-}
-
-message ComputeEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message Embedding {
-    bytes digest = 1;
-    repeated float dimensions = 2;
-}
diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto
index a183a1de5b..ef9f46ac49 100644
--- a/crates/proto/proto/zed.proto
+++ b/crates/proto/proto/zed.proto
@@ -208,10 +208,6 @@ message Envelope {
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
 
-        GetCachedEmbeddings get_cached_embeddings = 189;
-        GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
-        ComputeEmbeddings compute_embeddings = 191;
-        ComputeEmbeddingsResponse compute_embeddings_response = 192;
 
         UpdateChannelMessage update_channel_message = 170;
         ChannelMessageUpdate channel_message_update = 171;
@@ -394,6 +390,7 @@ message Envelope {
     reserved 166 to 169;
     reserved 177 to 185;
     reserved 188;
+    reserved 189 to 192;
     reserved 193 to 195;
     reserved 197;
     reserved 200 to 202;
diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs
index fb38c282df..2f34bd2cd8 100644
--- a/crates/proto/src/proto.rs
+++ b/crates/proto/src/proto.rs
@@ -49,8 +49,6 @@ messages!(
     (ChannelMessageUpdate, Foreground),
     (CloseBuffer, Foreground),
     (Commit, Background),
-    (ComputeEmbeddings, Background),
-    (ComputeEmbeddingsResponse, Background),
     (CopyProjectEntry, Foreground),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
@@ -82,8 +80,6 @@ messages!(
     (FormatBuffers, Foreground),
     (FormatBuffersResponse, Foreground),
     (FuzzySearchUsers, Foreground),
-    (GetCachedEmbeddings, Background),
-    (GetCachedEmbeddingsResponse, Background),
     (GetChannelMembers, Foreground),
     (GetChannelMembersResponse, Foreground),
     (GetChannelMessages, Background),
@@ -319,7 +315,6 @@ request_messages!(
     (CancelCall, Ack),
     (Commit, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
-    (ComputeEmbeddings, ComputeEmbeddingsResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
@@ -332,7 +327,6 @@ request_messages!(
     (ApplyCodeActionKind, ApplyCodeActionKindResponse),
     (FormatBuffers, FormatBuffersResponse),
     (FuzzySearchUsers, UsersResponse),
-    (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
     (GetChannelMembers, GetChannelMembersResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
     (GetChannelMessagesById, GetChannelMessagesResponse),
diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs
index ef4443283f..8ca47a4023 100644
--- a/crates/semantic_index/src/embedding.rs
+++ b/crates/semantic_index/src/embedding.rs
@@ -1,9 +1,7 @@
-mod cloud;
 mod lmstudio;
 mod ollama;
 mod open_ai;
 
-pub use cloud::*;
 pub use lmstudio::*;
 pub use ollama::*;
 pub use open_ai::*;
diff --git a/crates/semantic_index/src/embedding/cloud.rs b/crates/semantic_index/src/embedding/cloud.rs
deleted file mode 100644
index fdbfca5873..0000000000
--- a/crates/semantic_index/src/embedding/cloud.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-use crate::{Embedding, EmbeddingProvider, TextToEmbed};
-use anyhow::{Context as _, Result, anyhow};
-use client::{Client, proto};
-use collections::HashMap;
-use futures::{FutureExt, future::BoxFuture};
-use std::sync::Arc;
-
-pub struct CloudEmbeddingProvider {
-    model: String,
-    client: Arc<Client>,
-}
-
-impl CloudEmbeddingProvider {
-    pub fn new(client: Arc<Client>) -> Self {
-        Self {
-            model: "openai/text-embedding-3-small".into(),
-            client,
-        }
-    }
-}
-
-impl EmbeddingProvider for CloudEmbeddingProvider {
-    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
-        // First, fetch any embeddings that are cached based on the requested texts' digests.
-        // Then compute any embeddings that are missing.
-        async move {
-            if !self.client.status().borrow().is_connected() {
-                return Err(anyhow!("sign in required"));
-            }
-
-            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
-                model: self.model.clone(),
-                digests: texts
-                    .iter()
-                    .map(|to_embed| to_embed.digest.to_vec())
-                    .collect(),
-            });
-            let mut embeddings = cached_embeddings
-                .await
-                .context("failed to fetch cached embeddings via cloud model")?
-                .embeddings
-                .into_iter()
-                .map(|embedding| {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    Ok((digest, embedding.dimensions))
-                })
-                .collect::<Result<HashMap<_, _>>>()?;
-
-            let compute_embeddings_request = proto::ComputeEmbeddings {
-                model: self.model.clone(),
-                texts: texts
-                    .iter()
-                    .filter_map(|to_embed| {
-                        if embeddings.contains_key(&to_embed.digest) {
-                            None
-                        } else {
-                            Some(to_embed.text.to_string())
-                        }
-                    })
-                    .collect(),
-            };
-            if !compute_embeddings_request.texts.is_empty() {
-                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
-                for embedding in missing_embeddings.embeddings {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    embeddings.insert(digest, embedding.dimensions);
-                }
-            }
-
-            texts
-                .iter()
-                .map(|to_embed| {
-                    let embedding =
-                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
-                            format!("server did not return an embedding for {:?}", to_embed)
-                        })?;
-                    Ok(Embedding::new(embedding))
-                })
-                .collect()
-        }
-        .boxed()
-    }
-
-    fn batch_size(&self) -> usize {
-        2048
-    }
-}
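Note: the deleted CloudEmbeddingProvider implemented a cache-then-compute flow: fetch cached embeddings by digest, request only the misses, then return results in the caller's original order. A provider-agnostic sketch of that pattern, under the assumption that a synchronous compute callback stands in for the RPC round trip (all names are hypothetical, not the removed API):

    use std::collections::HashMap;

    // Stand-in for the removed TextToEmbed: a text plus its SHA-256 digest.
    struct TextToEmbed {
        digest: [u8; 32],
        text: String,
    }

    // Look up each text's embedding by digest, compute only the misses,
    // and return embeddings in the same order as the input texts.
    fn embed_with_cache(
        texts: &[TextToEmbed],
        cache: &mut HashMap<[u8; 32], Vec<f32>>,
        compute: impl Fn(&str) -> Vec<f32>,
    ) -> Vec<Vec<f32>> {
        for to_embed in texts {
            if !cache.contains_key(&to_embed.digest) {
                cache.insert(to_embed.digest, compute(&to_embed.text));
            }
        }
        texts
            .iter()
            .map(|to_embed| cache[&to_embed.digest].clone())
            .collect()
    }

Batching (the removed provider reported a batch_size of 2048) and error handling for embeddings the server fails to return are omitted for brevity.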