Add feature-flagged access to LLM service (#16136)

This PR adds feature-flagged access to the LLM service.

We've repurposed the `language-models` feature flag to grant access to
Claude 3.5 Sonnet through the Zed provider.
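
For illustration, here is a toy sketch of the per-user flag lookup that backs this gating (the server-side equivalent is `db.get_user_flags(...)` in the diff below; `FlagStore` and the integer user IDs are hypothetical):

```rust
use std::collections::{HashMap, HashSet};

/// Toy stand-in for the per-user feature-flag lookup the server performs
/// (`db.get_user_flags(user_id)` in the diff below).
struct FlagStore {
    flags_by_user: HashMap<u64, HashSet<String>>,
}

impl FlagStore {
    fn user_has_flag(&self, user_id: u64, flag: &str) -> bool {
        self.flags_by_user
            .get(&user_id)
            .map_or(false, |flags| flags.contains(flag))
    }
}

fn main() {
    let mut flags_by_user = HashMap::new();
    flags_by_user.insert(1, HashSet::from(["language-models".to_string()]));
    let store = FlagStore { flags_by_user };

    // User 1 carries the repurposed flag, so they get the Zed provider.
    assert!(store.user_has_flag(1, "language-models"));
    assert!(!store.user_has_flag(2, "language-models"));
}
```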

The remaining RPC endpoints that were previously behind the
`language-models` feature flag are now behind a staff check.
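
Distilled to its shape, each legacy endpoint now awaits a single shared authorization helper before doing any work. Below is a minimal synchronous sketch with a simplified stand-in for `UserSession` (the real helper is async and appears in the diff below):

```rust
use anyhow::{anyhow, Result};

/// Simplified stand-in for collab's `UserSession` (the real check is async
/// because it may touch the database; this sketch is synchronous).
struct UserSession {
    staff: bool,
}

/// Staff-only gate, mirroring `authorize_access_to_legacy_llm_endpoints`
/// in the diff below.
fn authorize_legacy_llm_access(session: &UserSession) -> Result<()> {
    if session.staff {
        Ok(())
    } else {
        Err(anyhow!("permission denied"))
    }
}

/// Each legacy endpoint calls the gate before doing any work.
fn count_tokens_sketch(session: &UserSession) -> Result<()> {
    authorize_legacy_llm_access(session)?;
    // ... the endpoint body would run here ...
    Ok(())
}

fn main() {
    assert!(count_tokens_sketch(&UserSession { staff: true }).is_ok());
    assert!(count_tokens_sketch(&UserSession { staff: false }).is_err());
}
```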

We also put some Zed Pro-related messaging behind a feature flag.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Marshall Bowers 2024-08-12 18:13:40 -04:00 committed by GitHub
parent 3bebb8b401
commit 8a148f3a13
2 changed files with 49 additions and 43 deletions


@@ -4501,7 +4501,7 @@ async fn count_language_model_tokens(
     let Some(session) = session.for_user() else {
        return Err(anyhow!("user not found"))?;
     };
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;

     let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
         proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
@@ -4621,7 +4621,7 @@ async fn compute_embeddings(
     api_key: Option<Arc<str>>,
 ) -> Result<()> {
     let api_key = api_key.context("no OpenAI API key configured on the server")?;
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;

     let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
         proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
@@ -4685,7 +4685,7 @@ async fn get_cached_embeddings(
     response: Response<proto::GetCachedEmbeddings>,
     session: UserSession,
 ) -> Result<()> {
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;

     let db = session.db().await;
     let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
@@ -4699,14 +4699,15 @@ async fn get_cached_embeddings(
     Ok(())
 }

-async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
-    let db = session.db().await;
-    let flags = db.get_user_flags(session.user_id()).await?;
-    if flags.iter().any(|flag| flag == "language-models") {
-        return Ok(());
-    }
-
-    Err(anyhow!("permission denied"))?
+/// This is leftover from before the LLM service.
+///
+/// The endpoints protected by this check will be moved there eventually.
+async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
+    if session.is_staff() {
+        Ok(())
+    } else {
+        Err(anyhow!("permission denied"))?
+    }
 }

 /// Get a Supermaven API key for the user
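
One idiom worth noting in the hunk above: `Err(anyhow!("permission denied"))?` leans on the `?` operator to convert the `anyhow::Error` into the function's `Error` return type. A minimal sketch of the pattern, with a hypothetical `Error` type standing in for collab's:

```rust
use anyhow::anyhow;

/// Hypothetical error type standing in for collab's; all the idiom needs
/// is a `From<anyhow::Error>` impl.
#[derive(Debug)]
struct Error(anyhow::Error);

impl From<anyhow::Error> for Error {
    fn from(err: anyhow::Error) -> Self {
        Error(err)
    }
}

fn check(allowed: bool) -> Result<(), Error> {
    if allowed {
        Ok(())
    } else {
        // `?` converts `anyhow::Error` into `Error` via the `From` impl
        // while returning early with the error.
        Err(anyhow!("permission denied"))?
    }
}

fn main() {
    assert!(check(true).is_ok());
    assert!(check(false).is_err());
}
```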
@@ -4915,12 +4916,13 @@ async fn get_llm_api_token(
     response: Response<proto::GetLlmToken>,
     session: UserSession,
 ) -> Result<()> {
-    if !session.is_staff() {
+    let db = session.db().await;
+    let flags = db.get_user_flags(session.user_id()).await?;
+    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
         Err(anyhow!("permission denied"))?
     }

-    let db = session.db().await;
     let user_id = session.user_id();
     let user = db
         .get_user_by_id(user_id)
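
The net effect of this last hunk: `get_llm_api_token` is now open to staff or to any user carrying the `language-models` flag, rather than staff only. The `let db` binding also moves above the check because the flag lookup needs the database handle before authorization runs, which is why the later duplicate binding is removed. A pure-function sketch of the new deny condition (parameter names are illustrative stand-ins for the session and flag lookup):

```rust
/// Sketch of the new gate in `get_llm_api_token`: requests are denied unless
/// the user is staff or holds the repurposed `language-models` feature flag.
fn llm_token_denied(is_staff: bool, flags: &[String]) -> bool {
    !is_staff && !flags.iter().any(|flag| flag == "language-models")
}

fn main() {
    assert!(!llm_token_denied(true, &[]));
    assert!(!llm_token_denied(false, &["language-models".into()]));
    assert!(llm_token_denied(false, &[]));
}
```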