Add feature-flagged access to LLM service (#16136)
This PR adds feature-flagged access to the LLM service. We've repurposed the `language-models` feature flag to be used for providing access to Claude 3.5 Sonnet through the Zed provider. The remaining RPC endpoints that were previously behind the `language-models` feature flag are now behind a staff check. We also put some Zed Pro related messaging behind a feature flag.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
This commit is contained in:
parent
3bebb8b401
commit
8a148f3a13
2 changed files with 49 additions and 43 deletions
|
@@ -4501,7 +4501,7 @@ async fn count_language_model_tokens(
|
|||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||
|
@@ -4621,7 +4621,7 @@ async fn compute_embeddings(
|
|||
api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
|
@@ -4685,7 +4685,7 @@ async fn get_cached_embeddings(
|
|||
response: Response<proto::GetCachedEmbeddings>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let db = session.db().await;
|
||||
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
|
||||
|
@@ -4699,14 +4699,15 @@ async fn get_cached_embeddings(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
|
||||
let db = session.db().await;
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if flags.iter().any(|flag| flag == "language-models") {
|
||||
return Ok(());
|
||||
/// This is leftover from before the LLM service.
|
||||
///
|
||||
/// The endpoints protected by this check will be moved there eventually.
|
||||
async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
|
||||
if session.is_staff() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
/// Get a Supermaven API key for the user
|
||||
|
@@ -4915,12 +4916,13 @@ async fn get_llm_api_token(
|
|||
response: Response<proto::GetLlmToken>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
if !session.is_staff() {
|
||||
let db = session.db().await;
|
||||
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
let db = session.db().await;
|
||||
|
||||
let user_id = session.user_id();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue