collab: Remove code for embeddings (#29310)
This PR removes the embeddings-related code from collab and the protocol, as we weren't using it anywhere.

Release Notes:

- N/A
This commit is contained in:
parent 218496744c
commit ecc600a68f

8 changed files with 1 addition and 267 deletions
Cargo.lock (generated)

```diff
@@ -3017,7 +3017,6 @@ dependencies = [
  "nanoid",
  "node_runtime",
  "notifications",
- "open_ai",
  "parking_lot",
  "pretty_assertions",
  "project",
```
```diff
@@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
 livekit_api.workspace = true
 log.workspace = true
 nanoid.workspace = true
-open_ai.workspace = true
 parking_lot.workspace = true
 prometheus = "0.14"
 prost.workspace = true
```
```diff
@@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
 use http_client::HttpClient;
-use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
 use reqwest_client::ReqwestClient;
 use rpc::proto::split_repository_update;
-use sha2::Digest;
 use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
 
 use futures::{
```
```diff
@@ -437,18 +435,6 @@ impl Server {
                         .await
                     }
                 }
-            })
-            .add_request_handler(get_cached_embeddings)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    compute_embeddings(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                    )
-                }
-            });
             });
 
         Arc::new(server)
```
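For context on the hunk above: collab registers RPC handlers with a builder-style chain, and the deleted `compute_embeddings` handler used a `move` closure to thread server configuration into each request. Below is a minimal sketch of that pattern using stand-in types, not collab's real `Server`, `Response`, or `Session`:

```rust
use std::sync::Arc;

struct Config {
    openai_api_key: Option<Arc<str>>,
}

struct AppState {
    config: Config,
}

struct Server {
    handlers: Vec<Box<dyn Fn(&str)>>,
}

impl Server {
    // Builder-style registration, as in the chain shown in the diff.
    fn add_request_handler(mut self, handler: impl Fn(&str) + 'static) -> Self {
        self.handlers.push(Box::new(handler));
        self
    }
}

fn main() {
    let app_state = Arc::new(AppState {
        config: Config {
            openai_api_key: None,
        },
    });

    // Mirrors the deleted registration: clone the Arc outside, then `move` it
    // into the closure so every request sees the captured configuration.
    let server = Server { handlers: Vec::new() }.add_request_handler({
        let app_state = app_state.clone();
        move |request| {
            let _api_key = app_state.config.openai_api_key.clone();
            println!("handling {request}");
        }
    });

    (server.handlers[0])("ComputeEmbeddings");
}
```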
```diff
@@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
     }
 }
 
-struct ZedProComputeEmbeddingsRateLimit;
-
-impl RateLimit for ZedProComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:compute-embeddings"
-    }
-}
-
-struct FreeComputeEmbeddingsRateLimit;
-
-impl RateLimit for FreeComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:compute-embeddings"
-    }
-}
-
-async fn compute_embeddings(
-    request: proto::ComputeEmbeddings,
-    response: Response<proto::ComputeEmbeddings>,
-    session: Session,
-    api_key: Option<Arc<str>>,
-) -> Result<()> {
-    let api_key = api_key.context("no OpenAI API key configured on the server")?;
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
-        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
-        proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let embeddings = match request.model.as_str() {
-        "openai/text-embedding-3-small" => {
-            open_ai::embed(
-                session.http_client.as_ref(),
-                OPEN_AI_API_URL,
-                &api_key,
-                OpenAiEmbeddingModel::TextEmbedding3Small,
-                request.texts.iter().map(|text| text.as_str()),
-            )
-            .await?
-        }
-        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
-    };
-
-    let embeddings = request
-        .texts
-        .iter()
-        .map(|text| {
-            let mut hasher = sha2::Sha256::new();
-            hasher.update(text.as_bytes());
-            let result = hasher.finalize();
-            result.to_vec()
-        })
-        .zip(
-            embeddings
-                .data
-                .into_iter()
-                .map(|embedding| embedding.embedding),
-        )
-        .collect::<HashMap<_, _>>();
-
-    let db = session.db().await;
-    db.save_embeddings(&request.model, &embeddings)
-        .await
-        .context("failed to save embeddings")
-        .trace_err();
-
-    response.send(proto::ComputeEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
-async fn get_cached_embeddings(
-    request: proto::GetCachedEmbeddings,
-    response: Response<proto::GetCachedEmbeddings>,
-    session: Session,
-) -> Result<()> {
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let db = session.db().await;
-    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
-
-    response.send(proto::GetCachedEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
 /// This is leftover from before the LLM service.
 ///
 /// The endpoints protected by this check will be moved there eventually.
```
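The two deleted rate limiters differ only in their env var, default, and bucket name. A small sketch of the shared capacity logic, written here as a plain function rather than collab's `RateLimit` trait:

```rust
// Read an env-var override for the hourly capacity, falling back to a
// hard-coded default, exactly as the deleted `capacity()` methods did.
fn capacity_from_env(var: &str, default: usize) -> usize {
    std::env::var(var)
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(default)
}

fn main() {
    // Matches the deleted defaults: 5000/hour for Pro, a tenth of that for Free.
    let pro = capacity_from_env("EMBED_TEXTS_RATE_LIMIT_PER_HOUR", 5000);
    let free = capacity_from_env("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE", 5000 / 10);
    println!("pro: {pro}/hour, free: {free}/hour");
}
```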
```diff
@@ -188,26 +188,3 @@ enum LanguageModelProvider {
     Google = 2;
     Zed = 3;
 }
-
-message GetCachedEmbeddings {
-    string model = 1;
-    repeated bytes digests = 2;
-}
-
-message GetCachedEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message ComputeEmbeddings {
-    string model = 1;
-    repeated string texts = 2;
-}
-
-message ComputeEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message Embedding {
-    bytes digest = 1;
-    repeated float dimensions = 2;
-}
```
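The deleted `Embedding` message pairs a `bytes digest` with the vector's `dimensions`: both sides of this protocol identified a text by its SHA-256 digest, so cached vectors could be fetched without resending the text. A sketch of that keying, assuming the `sha2` crate as in the deleted server code:

```rust
use sha2::{Digest, Sha256};

// Compute the cache key for a text, as both the deleted server handler and
// CloudEmbeddingProvider did.
fn text_digest(text: &str) -> Vec<u8> {
    let mut hasher = Sha256::new();
    hasher.update(text.as_bytes());
    hasher.finalize().to_vec()
}

fn main() {
    let digest = text_digest("fn main() {}");
    // SHA-256 digests are 32 bytes; the client converted them to [u8; 32] keys.
    assert_eq!(digest.len(), 32);
}
```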
```diff
@@ -208,10 +208,6 @@ message Envelope {
 
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
-        GetCachedEmbeddings get_cached_embeddings = 189;
-        GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
-        ComputeEmbeddings compute_embeddings = 191;
-        ComputeEmbeddingsResponse compute_embeddings_response = 192;
 
         UpdateChannelMessage update_channel_message = 170;
         ChannelMessageUpdate channel_message_update = 171;
```
```diff
@@ -394,6 +390,7 @@ message Envelope {
     reserved 166 to 169;
     reserved 177 to 185;
     reserved 188;
+    reserved 189 to 192;
     reserved 193 to 195;
     reserved 197;
     reserved 200 to 202;
```

This is the diff's single addition: reserving tags 189 to 192 keeps future `Envelope` fields from reusing the wire numbers of the deleted embedding messages.
```diff
@@ -49,8 +49,6 @@ messages!(
     (ChannelMessageUpdate, Foreground),
     (CloseBuffer, Foreground),
     (Commit, Background),
-    (ComputeEmbeddings, Background),
-    (ComputeEmbeddingsResponse, Background),
     (CopyProjectEntry, Foreground),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
```
```diff
@@ -82,8 +80,6 @@ messages!(
     (FormatBuffers, Foreground),
     (FormatBuffersResponse, Foreground),
     (FuzzySearchUsers, Foreground),
-    (GetCachedEmbeddings, Background),
-    (GetCachedEmbeddingsResponse, Background),
     (GetChannelMembers, Foreground),
     (GetChannelMembersResponse, Foreground),
     (GetChannelMessages, Background),
```
```diff
@@ -319,7 +315,6 @@ request_messages!(
     (CancelCall, Ack),
     (Commit, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
-    (ComputeEmbeddings, ComputeEmbeddingsResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
```
```diff
@@ -332,7 +327,6 @@ request_messages!(
     (ApplyCodeActionKind, ApplyCodeActionKindResponse),
     (FormatBuffers, FormatBuffersResponse),
     (FuzzySearchUsers, UsersResponse),
-    (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
     (GetChannelMembers, GetChannelMembersResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
     (GetChannelMessagesById, GetChannelMessagesResponse),
```
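Each `request_messages!` entry above pairs a request type with its response type. Conceptually (a hedged sketch, not the macro's actual expansion) the pairing amounts to an associated type:

```rust
// Hypothetical trait illustrating what a (Request, Response) pair in
// request_messages! expresses: sending a request of one type yields a
// response of the paired type.
trait RequestMessage {
    type Response;
}

struct GetChannelMembers;
struct GetChannelMembersResponse;

impl RequestMessage for GetChannelMembers {
    type Response = GetChannelMembersResponse;
}

fn main() {
    // The pairing is purely type-level; nothing to run here.
    let _req = GetChannelMembers;
    let _resp: <GetChannelMembers as RequestMessage>::Response = GetChannelMembersResponse;
}
```

Removing `(ComputeEmbeddings, ComputeEmbeddingsResponse)` and `(GetCachedEmbeddings, GetCachedEmbeddingsResponse)` drops those pairings along with the messages themselves.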
```diff
@@ -1,9 +1,7 @@
-mod cloud;
 mod lmstudio;
 mod ollama;
 mod open_ai;
 
-pub use cloud::*;
 pub use lmstudio::*;
 pub use ollama::*;
 pub use open_ai::*;
```
```diff
@@ -1,93 +0,0 @@
-use crate::{Embedding, EmbeddingProvider, TextToEmbed};
-use anyhow::{Context as _, Result, anyhow};
-use client::{Client, proto};
-use collections::HashMap;
-use futures::{FutureExt, future::BoxFuture};
-use std::sync::Arc;
-
-pub struct CloudEmbeddingProvider {
-    model: String,
-    client: Arc<Client>,
-}
-
-impl CloudEmbeddingProvider {
-    pub fn new(client: Arc<Client>) -> Self {
-        Self {
-            model: "openai/text-embedding-3-small".into(),
-            client,
-        }
-    }
-}
-
-impl EmbeddingProvider for CloudEmbeddingProvider {
-    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
-        // First, fetch any embeddings that are cached based on the requested texts' digests
-        // Then compute any embeddings that are missing.
-        async move {
-            if !self.client.status().borrow().is_connected() {
-                return Err(anyhow!("sign in required"));
-            }
-
-            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
-                model: self.model.clone(),
-                digests: texts
-                    .iter()
-                    .map(|to_embed| to_embed.digest.to_vec())
-                    .collect(),
-            });
-            let mut embeddings = cached_embeddings
-                .await
-                .context("failed to fetch cached embeddings via cloud model")?
-                .embeddings
-                .into_iter()
-                .map(|embedding| {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    Ok((digest, embedding.dimensions))
-                })
-                .collect::<Result<HashMap<_, _>>>()?;
-
-            let compute_embeddings_request = proto::ComputeEmbeddings {
-                model: self.model.clone(),
-                texts: texts
-                    .iter()
-                    .filter_map(|to_embed| {
-                        if embeddings.contains_key(&to_embed.digest) {
-                            None
-                        } else {
-                            Some(to_embed.text.to_string())
-                        }
-                    })
-                    .collect(),
-            };
-            if !compute_embeddings_request.texts.is_empty() {
-                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
-                for embedding in missing_embeddings.embeddings {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    embeddings.insert(digest, embedding.dimensions);
-                }
-            }
-
-            texts
-                .iter()
-                .map(|to_embed| {
-                    let embedding =
-                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
-                            format!("server did not return an embedding for {:?}", to_embed)
-                        })?;
-                    Ok(Embedding::new(embedding))
-                })
-                .collect()
-        }
-        .boxed()
-    }
-
-    fn batch_size(&self) -> usize {
-        2048
-    }
-}
```
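The deleted provider's `embed` ran a two-phase flow: request cached vectors by digest, compute only the misses, then return results in input order. A synchronous sketch of that flow with hypothetical stand-ins (no RPC, no real `Client`):

```rust
use std::collections::HashMap;

type Digest = [u8; 32];

// Hypothetical stand-in: `cache` plays the server's saved embeddings and
// `compute` plays the ComputeEmbeddings RPC.
fn embed_with_cache(
    texts: &[(Digest, &str)],
    cache: &mut HashMap<Digest, Vec<f32>>,
    compute: impl Fn(&str) -> Vec<f32>,
) -> Vec<Vec<f32>> {
    // Compute only the texts whose digests missed the cache.
    for (digest, text) in texts {
        if !cache.contains_key(digest) {
            cache.insert(*digest, compute(text));
        }
    }
    // Return vectors in request order, as the deleted embed() did.
    texts.iter().map(|(digest, _)| cache[digest].clone()).collect()
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert([0u8; 32], vec![1.0, 0.0]); // pretend this digest was already cached
    let texts = [([0u8; 32], "cached text"), ([1u8; 32], "new text")];
    let out = embed_with_cache(&texts, &mut cache, |_| vec![0.5, 0.5]);
    assert_eq!(out, vec![vec![1.0, 0.0], vec![0.5, 0.5]]);
}
```

Per the commit message, nothing called this provider anymore, which is why the whole file could be deleted.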