collab: Remove `CountLanguageModelTokens` RPC message (#29314)
This PR removes the `CountLanguageModelTokens` RPC message from collab. We were only using this for Google AI models through the Zed provider (which is only available to Zed staff). For now we're returning `0` as the token count for these models, but will bring token counting back soon. Release Notes: - N/A
This commit is contained in:
parent
ba3d82629e
commit
74442b68ea
7 changed files with 4 additions and 151 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2998,7 +2998,6 @@ dependencies = [
|
||||||
"git",
|
"git",
|
||||||
"git_hosting_providers",
|
"git_hosting_providers",
|
||||||
"git_ui",
|
"git_ui",
|
||||||
"google_ai",
|
|
||||||
"gpui",
|
"gpui",
|
||||||
"gpui_tokio",
|
"gpui_tokio",
|
||||||
"hex",
|
"hex",
|
||||||
|
|
|
@ -34,7 +34,6 @@ dashmap.workspace = true
|
||||||
derive_more.workspace = true
|
derive_more.workspace = true
|
||||||
envy = "0.4.2"
|
envy = "0.4.2"
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
google_ai.workspace = true
|
|
||||||
hex.workspace = true
|
hex.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
jsonwebtoken.workspace = true
|
jsonwebtoken.workspace = true
|
||||||
|
|
|
@ -3,7 +3,7 @@ mod connection_pool;
|
||||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||||
use crate::llm::LlmTokenClaims;
|
use crate::llm::LlmTokenClaims;
|
||||||
use crate::{
|
use crate::{
|
||||||
AppState, Config, Error, RateLimit, Result, auth,
|
AppState, Error, Result, auth,
|
||||||
db::{
|
db::{
|
||||||
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
|
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
|
||||||
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
|
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
|
||||||
|
@ -33,7 +33,6 @@ use chrono::Utc;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||||
use core::fmt::{self, Debug, Formatter};
|
use core::fmt::{self, Debug, Formatter};
|
||||||
use http_client::HttpClient;
|
|
||||||
use reqwest_client::ReqwestClient;
|
use reqwest_client::ReqwestClient;
|
||||||
use rpc::proto::split_repository_update;
|
use rpc::proto::split_repository_update;
|
||||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||||
|
@ -132,7 +131,6 @@ struct Session {
|
||||||
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
||||||
http_client: Arc<dyn HttpClient>,
|
|
||||||
/// The GeoIP country code for the user.
|
/// The GeoIP country code for the user.
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
geoip_country_code: Option<String>,
|
geoip_country_code: Option<String>,
|
||||||
|
@ -425,17 +423,7 @@ impl Server {
|
||||||
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
|
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
|
||||||
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
|
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
|
||||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||||
.add_message_handler(update_context)
|
.add_message_handler(update_context);
|
||||||
.add_request_handler({
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
move |request, response, session| {
|
|
||||||
let app_state = app_state.clone();
|
|
||||||
async move {
|
|
||||||
count_language_model_tokens(request, response, session, &app_state.config)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Arc::new(server)
|
Arc::new(server)
|
||||||
}
|
}
|
||||||
|
@ -764,7 +752,6 @@ impl Server {
|
||||||
peer: this.peer.clone(),
|
peer: this.peer.clone(),
|
||||||
connection_pool: this.connection_pool.clone(),
|
connection_pool: this.connection_pool.clone(),
|
||||||
app_state: this.app_state.clone(),
|
app_state: this.app_state.clone(),
|
||||||
http_client,
|
|
||||||
geoip_country_code,
|
geoip_country_code,
|
||||||
system_id,
|
system_id,
|
||||||
_executor: executor.clone(),
|
_executor: executor.clone(),
|
||||||
|
@ -3683,100 +3670,6 @@ async fn acknowledge_buffer_version(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count_language_model_tokens(
|
|
||||||
request: proto::CountLanguageModelTokens,
|
|
||||||
response: Response<proto::CountLanguageModelTokens>,
|
|
||||||
session: Session,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<()> {
|
|
||||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
|
||||||
|
|
||||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
|
|
||||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
|
||||||
proto::Plan::Free | proto::Plan::ZedProTrial => {
|
|
||||||
Box::new(FreeCountLanguageModelTokensRateLimit)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
session
|
|
||||||
.app_state
|
|
||||||
.rate_limiter
|
|
||||||
.check(&*rate_limit, session.user_id())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
|
||||||
Some(proto::LanguageModelProvider::Google) => {
|
|
||||||
let api_key = config
|
|
||||||
.google_ai_api_key
|
|
||||||
.as_ref()
|
|
||||||
.context("no Google AI API key configured on the server")?;
|
|
||||||
google_ai::count_tokens(
|
|
||||||
session.http_client.as_ref(),
|
|
||||||
google_ai::API_URL,
|
|
||||||
api_key,
|
|
||||||
serde_json::from_str(&request.request)?,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
}
|
|
||||||
_ => return Err(anyhow!("unsupported provider"))?,
|
|
||||||
};
|
|
||||||
|
|
||||||
response.send(proto::CountLanguageModelTokensResponse {
|
|
||||||
token_count: result.total_tokens as u32,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ZedProCountLanguageModelTokensRateLimit;
|
|
||||||
|
|
||||||
impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
|
|
||||||
fn capacity(&self) -> usize {
|
|
||||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
|
|
||||||
.ok()
|
|
||||||
.and_then(|v| v.parse().ok())
|
|
||||||
.unwrap_or(600) // Picked arbitrarily
|
|
||||||
}
|
|
||||||
|
|
||||||
fn refill_duration(&self) -> chrono::Duration {
|
|
||||||
chrono::Duration::hours(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_name(&self) -> &'static str {
|
|
||||||
"zed-pro:count-language-model-tokens"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FreeCountLanguageModelTokensRateLimit;
|
|
||||||
|
|
||||||
impl RateLimit for FreeCountLanguageModelTokensRateLimit {
|
|
||||||
fn capacity(&self) -> usize {
|
|
||||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
|
|
||||||
.ok()
|
|
||||||
.and_then(|v| v.parse().ok())
|
|
||||||
.unwrap_or(600 / 10) // Picked arbitrarily
|
|
||||||
}
|
|
||||||
|
|
||||||
fn refill_duration(&self) -> chrono::Duration {
|
|
||||||
chrono::Duration::hours(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn db_name(&self) -> &'static str {
|
|
||||||
"free:count-language-model-tokens"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This is leftover from before the LLM service.
|
|
||||||
///
|
|
||||||
/// The endpoints protected by this check will be moved there eventually.
|
|
||||||
async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
|
|
||||||
if session.is_staff() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("permission denied"))?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a Supermaven API key for the user
|
/// Get a Supermaven API key for the user
|
||||||
async fn get_supermaven_api_key(
|
async fn get_supermaven_api_key(
|
||||||
_request: proto::GetSupermavenApiKey,
|
_request: proto::GetSupermavenApiKey,
|
||||||
|
|
|
@ -686,24 +686,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
match self.model.clone() {
|
match self.model.clone() {
|
||||||
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
||||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||||
CloudModel::Google(model) => {
|
CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
|
||||||
let client = self.client.clone();
|
|
||||||
let request = into_google(request, model.id().into());
|
|
||||||
let request = google_ai::CountTokensRequest {
|
|
||||||
contents: request.contents,
|
|
||||||
};
|
|
||||||
async move {
|
|
||||||
let request = serde_json::to_string(&request)?;
|
|
||||||
let response = client
|
|
||||||
.request(proto::CountLanguageModelTokens {
|
|
||||||
provider: proto::LanguageModelProvider::Google as i32,
|
|
||||||
request,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
Ok(response.token_count as usize)
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -172,19 +172,3 @@ enum LanguageModelRole {
|
||||||
LanguageModelSystem = 2;
|
LanguageModelSystem = 2;
|
||||||
reserved 3;
|
reserved 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CountLanguageModelTokens {
|
|
||||||
LanguageModelProvider provider = 1;
|
|
||||||
string request = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CountLanguageModelTokensResponse {
|
|
||||||
uint32 token_count = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum LanguageModelProvider {
|
|
||||||
Anthropic = 0;
|
|
||||||
OpenAI = 1;
|
|
||||||
Google = 2;
|
|
||||||
Zed = 3;
|
|
||||||
}
|
|
||||||
|
|
|
@ -206,9 +206,6 @@ message Envelope {
|
||||||
GetImplementation get_implementation = 162;
|
GetImplementation get_implementation = 162;
|
||||||
GetImplementationResponse get_implementation_response = 163;
|
GetImplementationResponse get_implementation_response = 163;
|
||||||
|
|
||||||
CountLanguageModelTokens count_language_model_tokens = 230;
|
|
||||||
CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
|
|
||||||
|
|
||||||
UpdateChannelMessage update_channel_message = 170;
|
UpdateChannelMessage update_channel_message = 170;
|
||||||
ChannelMessageUpdate channel_message_update = 171;
|
ChannelMessageUpdate channel_message_update = 171;
|
||||||
|
|
||||||
|
@ -397,6 +394,7 @@ message Envelope {
|
||||||
reserved 205 to 206;
|
reserved 205 to 206;
|
||||||
reserved 221;
|
reserved 221;
|
||||||
reserved 224 to 229;
|
reserved 224 to 229;
|
||||||
|
reserved 230 to 231;
|
||||||
reserved 246;
|
reserved 246;
|
||||||
reserved 270;
|
reserved 270;
|
||||||
reserved 247 to 254;
|
reserved 247 to 254;
|
||||||
|
|
|
@ -50,8 +50,6 @@ messages!(
|
||||||
(CloseBuffer, Foreground),
|
(CloseBuffer, Foreground),
|
||||||
(Commit, Background),
|
(Commit, Background),
|
||||||
(CopyProjectEntry, Foreground),
|
(CopyProjectEntry, Foreground),
|
||||||
(CountLanguageModelTokens, Background),
|
|
||||||
(CountLanguageModelTokensResponse, Background),
|
|
||||||
(CreateBufferForPeer, Foreground),
|
(CreateBufferForPeer, Foreground),
|
||||||
(CreateChannel, Foreground),
|
(CreateChannel, Foreground),
|
||||||
(CreateChannelResponse, Foreground),
|
(CreateChannelResponse, Foreground),
|
||||||
|
@ -374,7 +372,6 @@ request_messages!(
|
||||||
(PerformRename, PerformRenameResponse),
|
(PerformRename, PerformRenameResponse),
|
||||||
(Ping, Ack),
|
(Ping, Ack),
|
||||||
(PrepareRename, PrepareRenameResponse),
|
(PrepareRename, PrepareRenameResponse),
|
||||||
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
|
|
||||||
(RefreshInlayHints, Ack),
|
(RefreshInlayHints, Ack),
|
||||||
(RefreshCodeLens, Ack),
|
(RefreshCodeLens, Ack),
|
||||||
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue