diff --git a/Cargo.lock b/Cargo.lock index 1b4e88b2a6..0cae49fa66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2998,7 +2998,6 @@ dependencies = [ "git", "git_hosting_providers", "git_ui", - "google_ai", "gpui", "gpui_tokio", "hex", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index d6df4be6df..0fc1422129 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -34,7 +34,6 @@ dashmap.workspace = true derive_more.workspace = true envy = "0.4.2" futures.workspace = true -google_ai.workspace = true hex.workspace = true http_client.workspace = true jsonwebtoken.workspace = true diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 76f5f94641..e2dacc5389 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,7 +3,7 @@ mod connection_pool; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::llm::LlmTokenClaims; use crate::{ - AppState, Config, Error, RateLimit, Result, auth, + AppState, Error, Result, auth, db::{ self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, @@ -33,7 +33,6 @@ use chrono::Utc; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; -use http_client::HttpClient; use reqwest_client::ReqwestClient; use rpc::proto::split_repository_update; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; @@ -132,7 +131,6 @@ struct Session { connection_pool: Arc>, app_state: Arc, supermaven_client: Option>, - http_client: Arc, /// The GeoIP country code for the user. #[allow(unused)] geoip_country_code: Option, @@ -425,17 +423,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_message_handler(broadcast_project_message_from_host::) - .add_message_handler(update_context) - .add_request_handler({ - let app_state = app_state.clone(); - move |request, response, session| { - let app_state = app_state.clone(); - async move { - count_language_model_tokens(request, response, session, &app_state.config) - .await - } - } - }); + .add_message_handler(update_context); Arc::new(server) } @@ -764,7 +752,6 @@ impl Server { peer: this.peer.clone(), connection_pool: this.connection_pool.clone(), app_state: this.app_state.clone(), - http_client, geoip_country_code, system_id, _executor: executor.clone(), @@ -3683,100 +3670,6 @@ async fn acknowledge_buffer_version( Ok(()) } -async fn count_language_model_tokens( - request: proto::CountLanguageModelTokens, - response: Response, - session: Session, - config: &Config, -) -> Result<()> { - authorize_access_to_legacy_llm_endpoints(&session).await?; - - let rate_limit: Box = match session.current_plan(&session.db().await).await? { - proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit), - proto::Plan::Free | proto::Plan::ZedProTrial => { - Box::new(FreeCountLanguageModelTokensRateLimit) - } - }; - - session - .app_state - .rate_limiter - .check(&*rate_limit, session.user_id()) - .await?; - - let result = match proto::LanguageModelProvider::from_i32(request.provider) { - Some(proto::LanguageModelProvider::Google) => { - let api_key = config - .google_ai_api_key - .as_ref() - .context("no Google AI API key configured on the server")?; - google_ai::count_tokens( - session.http_client.as_ref(), - google_ai::API_URL, - api_key, - serde_json::from_str(&request.request)?, - ) - .await? - } - _ => return Err(anyhow!("unsupported provider"))?, - }; - - response.send(proto::CountLanguageModelTokensResponse { - token_count: result.total_tokens as u32, - })?; - - Ok(()) -} - -struct ZedProCountLanguageModelTokensRateLimit; - -impl RateLimit for ZedProCountLanguageModelTokensRateLimit { - fn capacity(&self) -> usize { - std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(600) // Picked arbitrarily - } - - fn refill_duration(&self) -> chrono::Duration { - chrono::Duration::hours(1) - } - - fn db_name(&self) -> &'static str { - "zed-pro:count-language-model-tokens" - } -} - -struct FreeCountLanguageModelTokensRateLimit; - -impl RateLimit for FreeCountLanguageModelTokensRateLimit { - fn capacity(&self) -> usize { - std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(600 / 10) // Picked arbitrarily - } - - fn refill_duration(&self) -> chrono::Duration { - chrono::Duration::hours(1) - } - - fn db_name(&self) -> &'static str { - "free:count-language-model-tokens" - } -} - -/// This is leftover from before the LLM service. -/// -/// The endpoints protected by this check will be moved there eventually. -async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> { - if session.is_staff() { - Ok(()) - } else { - Err(anyhow!("permission denied"))? - } -} - /// Get a Supermaven API key for the user async fn get_supermaven_api_key( _request: proto::GetSupermavenApiKey, diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index bf6fc5a4e3..b8bc86d406 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -686,24 +686,7 @@ impl LanguageModel for CloudLanguageModel { match self.model.clone() { CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx), CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx), - CloudModel::Google(model) => { - let client = self.client.clone(); - let request = into_google(request, model.id().into()); - let request = google_ai::CountTokensRequest { - contents: request.contents, - }; - async move { - let request = serde_json::to_string(&request)?; - let response = client - .request(proto::CountLanguageModelTokens { - provider: proto::LanguageModelProvider::Google as i32, - request, - }) - .await?; - Ok(response.token_count as usize) - } - .boxed() - } + CloudModel::Google(_model) => async move { Ok(0) }.boxed(), } } diff --git a/crates/proto/proto/ai.proto b/crates/proto/proto/ai.proto index c39345c3c2..67c2224387 100644 --- a/crates/proto/proto/ai.proto +++ b/crates/proto/proto/ai.proto @@ -172,19 +172,3 @@ enum LanguageModelRole { LanguageModelSystem = 2; reserved 3; } - -message CountLanguageModelTokens { - LanguageModelProvider provider = 1; - string request = 2; -} - -message CountLanguageModelTokensResponse { - uint32 token_count = 1; -} - -enum LanguageModelProvider { - Anthropic = 0; - OpenAI = 1; - Google = 2; - Zed = 3; -} diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index ef9f46ac49..74db3ce539 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -206,9 +206,6 @@ message Envelope { GetImplementation get_implementation = 162; GetImplementationResponse get_implementation_response = 163; - CountLanguageModelTokens count_language_model_tokens = 230; - CountLanguageModelTokensResponse count_language_model_tokens_response = 231; - UpdateChannelMessage update_channel_message = 170; ChannelMessageUpdate channel_message_update = 171; @@ -397,6 +394,7 @@ message Envelope { reserved 205 to 206; reserved 221; reserved 224 to 229; + reserved 230 to 231; reserved 246; reserved 270; reserved 247 to 254; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index 2f34bd2cd8..03e0bda101 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -50,8 +50,6 @@ messages!( (CloseBuffer, Foreground), (Commit, Background), (CopyProjectEntry, Foreground), - (CountLanguageModelTokens, Background), - (CountLanguageModelTokensResponse, Background), (CreateBufferForPeer, Foreground), (CreateChannel, Foreground), (CreateChannelResponse, Foreground), @@ -374,7 +372,6 @@ request_messages!( (PerformRename, PerformRenameResponse), (Ping, Ack), (PrepareRename, PrepareRenameResponse), - (CountLanguageModelTokens, CountLanguageModelTokensResponse), (RefreshInlayHints, Ack), (RefreshCodeLens, Ack), (RejoinChannelBuffers, RejoinChannelBuffersResponse),