collab: Remove code for embeddings (#29310)
This PR removes the embeddings-related code from collab and the protocol, as we weren't using it anywhere. Release Notes: - N/A
This commit is contained in:
parent
218496744c
commit
ecc600a68f
8 changed files with 1 additions and 267 deletions
|
@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
|
|||
livekit_api.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
open_ai.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prometheus = "0.14"
|
||||
prost.workspace = true
|
||||
|
|
|
@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
|
|||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use http_client::HttpClient;
|
||||
use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::proto::split_repository_update;
|
||||
use sha2::Digest;
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
|
||||
use futures::{
|
||||
|
@ -437,18 +435,6 @@ impl Server {
|
|||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_request_handler(get_cached_embeddings)
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
compute_embeddings(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
app_state.config.openai_api_key.clone(),
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
Arc::new(server)
|
||||
|
@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
|
|||
}
|
||||
}
|
||||
|
||||
struct ZedProComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5000) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
struct FreeComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for FreeComputeEmbeddingsRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5000 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"free:compute-embeddings"
|
||||
}
|
||||
}
|
||||
|
||||
async fn compute_embeddings(
|
||||
request: proto::ComputeEmbeddings,
|
||||
response: Response<proto::ComputeEmbeddings>,
|
||||
session: Session,
|
||||
api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let embeddings = match request.model.as_str() {
|
||||
"openai/text-embedding-3-small" => {
|
||||
open_ai::embed(
|
||||
session.http_client.as_ref(),
|
||||
OPEN_AI_API_URL,
|
||||
&api_key,
|
||||
OpenAiEmbeddingModel::TextEmbedding3Small,
|
||||
request.texts.iter().map(|text| text.as_str()),
|
||||
)
|
||||
.await?
|
||||
}
|
||||
provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
|
||||
};
|
||||
|
||||
let embeddings = request
|
||||
.texts
|
||||
.iter()
|
||||
.map(|text| {
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(text.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
result.to_vec()
|
||||
})
|
||||
.zip(
|
||||
embeddings
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| embedding.embedding),
|
||||
)
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
let db = session.db().await;
|
||||
db.save_embeddings(&request.model, &embeddings)
|
||||
.await
|
||||
.context("failed to save embeddings")
|
||||
.trace_err();
|
||||
|
||||
response.send(proto::ComputeEmbeddingsResponse {
|
||||
embeddings: embeddings
|
||||
.into_iter()
|
||||
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
|
||||
.collect(),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_cached_embeddings(
|
||||
request: proto::GetCachedEmbeddings,
|
||||
response: Response<proto::GetCachedEmbeddings>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let db = session.db().await;
|
||||
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
|
||||
|
||||
response.send(proto::GetCachedEmbeddingsResponse {
|
||||
embeddings: embeddings
|
||||
.into_iter()
|
||||
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
|
||||
.collect(),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This is leftover from before the LLM service.
|
||||
///
|
||||
/// The endpoints protected by this check will be moved there eventually.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue