collab: Remove CountLanguageModelTokens RPC message (#29314)

This PR removes the `CountLanguageModelTokens` RPC message from collab.

We were only using this for Google AI models through the Zed provider
(which is only available to Zed staff).

For now we're returning `0`, but we will bring this back soon.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-23 19:10:47 -04:00 committed by GitHub
parent ba3d82629e
commit 74442b68ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 4 additions and 151 deletions

View file

@ -3,7 +3,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
AppState, Config, Error, RateLimit, Result, auth,
AppState, Error, Result, auth,
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@ -33,7 +33,6 @@ use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use http_client::HttpClient;
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
@ -132,7 +131,6 @@ struct Session {
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<dyn HttpClient>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@ -425,17 +423,7 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
count_language_model_tokens(request, response, session, &app_state.config)
.await
}
}
});
.add_message_handler(update_context);
Arc::new(server)
}
@ -764,7 +752,6 @@ impl Server {
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
app_state: this.app_state.clone(),
http_client,
geoip_country_code,
system_id,
_executor: executor.clone(),
@ -3683,100 +3670,6 @@ async fn acknowledge_buffer_version(
Ok(())
}
/// Counts the tokens of a language model request and sends the total back to the caller.
///
/// Access is gated by `authorize_access_to_legacy_llm_endpoints`, and each call is
/// charged against a rate limit selected by the user's current plan. Only the Google
/// provider is handled; any other provider value fails with "unsupported provider".
///
/// # Errors
///
/// Returns an error when authorization or the rate limit check fails, when no Google AI
/// API key is present in `config`, when `request.request` cannot be deserialized, when
/// the upstream token-count call fails, or when the resulting count does not fit in the
/// `u32` response field.
async fn count_language_model_tokens(
    request: proto::CountLanguageModelTokens,
    response: Response<proto::CountLanguageModelTokens>,
    session: Session,
    config: &Config,
) -> Result<()> {
    authorize_access_to_legacy_llm_endpoints(&session).await?;

    // Pick the rate limit bucket that matches the user's plan.
    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
        proto::Plan::Free | proto::Plan::ZedProTrial => {
            Box::new(FreeCountLanguageModelTokensRateLimit)
        }
    };

    session
        .app_state
        .rate_limiter
        .check(&*rate_limit, session.user_id())
        .await?;

    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
        Some(proto::LanguageModelProvider::Google) => {
            let api_key = config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            google_ai::count_tokens(
                session.http_client.as_ref(),
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
            )
            .await?
        }
        _ => return Err(anyhow!("unsupported provider").into()),
    };

    // The response field is a `u32`; fail loudly on an oversized count instead of
    // silently truncating it with an `as` cast.
    let token_count =
        u32::try_from(result.total_tokens).context("token count does not fit in a u32")?;

    response.send(proto::CountLanguageModelTokensResponse { token_count })?;
    Ok(())
}
/// Rate limit applied to token-count requests from users on the Zed Pro plan.
struct ZedProCountLanguageModelTokensRateLimit;

impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
    /// Reads the capacity from `COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR`,
    /// falling back to 600 when the variable is unset or does not parse.
    fn capacity(&self) -> usize {
        match std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR") {
            // Picked arbitrarily
            Ok(raw) => raw.parse().unwrap_or(600),
            Err(_) => 600,
        }
    }

    /// The bucket refills over one hour.
    fn refill_duration(&self) -> chrono::Duration {
        chrono::Duration::hours(1)
    }

    /// Name under which this limit is identified.
    fn db_name(&self) -> &'static str {
        "zed-pro:count-language-model-tokens"
    }
}
/// Rate limit applied to token-count requests from users on the Free and trial plans.
struct FreeCountLanguageModelTokensRateLimit;

impl RateLimit for FreeCountLanguageModelTokensRateLimit {
    /// Reads the capacity from `COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE`,
    /// falling back to a tenth of 600 when the variable is unset or does not parse.
    fn capacity(&self) -> usize {
        match std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE") {
            // Picked arbitrarily
            Ok(raw) => raw.parse().unwrap_or(600 / 10),
            Err(_) => 600 / 10,
        }
    }

    /// The bucket refills over one hour.
    fn refill_duration(&self) -> chrono::Duration {
        chrono::Duration::hours(1)
    }

    /// Name under which this limit is identified.
    fn db_name(&self) -> &'static str {
        "free:count-language-model-tokens"
    }
}
/// Allows only staff sessions through to the legacy LLM endpoints.
///
/// This check predates the LLM service; the endpoints it guards are expected to
/// move there eventually.
///
/// # Errors
///
/// Returns a "permission denied" error when the session does not belong to staff.
async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
    if !session.is_staff() {
        return Err(anyhow!("permission denied").into());
    }
    Ok(())
}
/// Get a Supermaven API key for the user
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,