diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index d379cb5aad..741d00d0e7 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -134,6 +134,11 @@ spec: secretKeyRef: name: anthropic key: api_key + - name: ANTHROPIC_STAFF_API_KEY + valueFrom: + secretKeyRef: + name: anthropic + key: staff_api_key - name: GOOGLE_AI_API_KEY valueFrom: secretKeyRef: diff --git a/crates/collab/migrations_llm/20240812184444_add_is_staff_to_usages.sql b/crates/collab/migrations_llm/20240812184444_add_is_staff_to_usages.sql new file mode 100644 index 0000000000..a50feb2e3f --- /dev/null +++ b/crates/collab/migrations_llm/20240812184444_add_is_staff_to_usages.sql @@ -0,0 +1 @@ +alter table usages add column is_staff boolean not null default false; diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 0463632e8f..9cae7713dc 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -166,6 +166,7 @@ pub struct Config { pub openai_api_key: Option>, pub google_ai_api_key: Option>, pub anthropic_api_key: Option>, + pub anthropic_staff_api_key: Option>, pub qwen2_7b_api_key: Option>, pub qwen2_7b_api_url: Option>, pub zed_client_checksum_seed: Option, @@ -216,6 +217,7 @@ impl Config { openai_api_key: None, google_ai_api_key: None, anthropic_api_key: None, + anthropic_staff_api_key: None, clickhouse_url: None, clickhouse_user: None, clickhouse_password: None, diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index fa74361941..4c249dcb4a 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -171,11 +171,19 @@ async fn perform_completion( let stream = match params.provider { LanguageModelProvider::Anthropic => { - let api_key = state - .config - .anthropic_api_key - .as_ref() - .context("no Anthropic AI API key configured on the server")?; + let api_key = if claims.is_staff { + state + .config + .anthropic_staff_api_key + .as_ref() + .context("no Anthropic AI staff API key configured on the server")? + } else { + state + .config + .anthropic_api_key + .as_ref() + .context("no Anthropic AI API key configured on the server")? + }; let mut request: anthropic::Request = serde_json::from_str(¶ms.provider_request.get())?; @@ -473,6 +481,7 @@ impl Drop for TokenCountingStream { .db .record_usage( claims.user_id as i32, + claims.is_staff, provider, &model, input_token_count, diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 3dc247ca12..e11adaf2a7 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -108,9 +108,11 @@ impl LlmDatabase { .await } + #[allow(clippy::too_many_arguments)] pub async fn record_usage( &self, user_id: i32, + is_staff: bool, provider: LanguageModelProvider, model_name: &str, input_token_count: usize, @@ -132,6 +134,7 @@ impl LlmDatabase { let requests_this_minute = self .update_usage_for_measure( user_id, + is_staff, model.id, &usages, UsageMeasure::RequestsPerMinute, @@ -143,6 +146,7 @@ impl LlmDatabase { let tokens_this_minute = self .update_usage_for_measure( user_id, + is_staff, model.id, &usages, UsageMeasure::TokensPerMinute, @@ -154,6 +158,7 @@ impl LlmDatabase { let tokens_this_day = self .update_usage_for_measure( user_id, + is_staff, model.id, &usages, UsageMeasure::TokensPerDay, @@ -165,6 +170,7 @@ impl LlmDatabase { let input_tokens_this_month = self .update_usage_for_measure( user_id, + is_staff, model.id, &usages, UsageMeasure::InputTokensPerMonth, @@ -176,6 +182,7 @@ impl LlmDatabase { let output_tokens_this_month = self .update_usage_for_measure( user_id, + is_staff, model.id, &usages, UsageMeasure::OutputTokensPerMonth, @@ -205,7 +212,11 @@ impl LlmDatabase { let day_since = now - Duration::days(5); let users_in_recent_minutes = usage::Entity::find() - .filter(usage::Column::Timestamp.gte(minute_since.naive_utc())) + .filter( + usage::Column::Timestamp + .gte(minute_since.naive_utc()) + .and(usage::Column::IsStaff.eq(false)), + ) .select_only() .column(usage::Column::UserId) .group_by(usage::Column::UserId) @@ -213,7 +224,11 @@ impl LlmDatabase { .await? as usize; let users_in_recent_days = usage::Entity::find() - .filter(usage::Column::Timestamp.gte(day_since.naive_utc())) + .filter( + usage::Column::Timestamp + .gte(day_since.naive_utc()) + .and(usage::Column::IsStaff.eq(false)), + ) .select_only() .column(usage::Column::UserId) .group_by(usage::Column::UserId) @@ -232,6 +247,7 @@ impl LlmDatabase { async fn update_usage_for_measure( &self, user_id: i32, + is_staff: bool, model_id: ModelId, usages: &[usage::Model], usage_measure: UsageMeasure, @@ -267,6 +283,7 @@ impl LlmDatabase { let mut model = usage::ActiveModel { user_id: ActiveValue::set(user_id), + is_staff: ActiveValue::set(is_staff), model_id: ActiveValue::set(model_id), measure_id: ActiveValue::set(measure_id), timestamp: ActiveValue::set(timestamp), diff --git a/crates/collab/src/llm/db/tables/usage.rs b/crates/collab/src/llm/db/tables/usage.rs index 5d131133c3..76f7e1b01d 100644 --- a/crates/collab/src/llm/db/tables/usage.rs +++ b/crates/collab/src/llm/db/tables/usage.rs @@ -15,6 +15,7 @@ pub struct Model { pub measure_id: UsageMeasureId, pub timestamp: DateTime, pub buckets: Vec, + pub is_staff: bool, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 8faad1caaf..336c3c9301 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -29,12 +29,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { let user_id = 123; let now = t0; - db.record_usage(user_id, provider, model, 1000, 0, now) + db.record_usage(user_id, false, provider, model, 1000, 0, now) .await .unwrap(); let now = t0 + Duration::seconds(10); - db.record_usage(user_id, provider, model, 2000, 0, now) + db.record_usage(user_id, false, provider, model, 2000, 0, now) .await .unwrap(); @@ -66,7 +66,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { ); let now = t0 + Duration::seconds(60); - db.record_usage(user_id, provider, model, 3000, 0, now) + db.record_usage(user_id, false, provider, model, 3000, 0, now) .await .unwrap(); @@ -98,7 +98,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { } ); - db.record_usage(user_id, provider, model, 4000, 0, now) + db.record_usage(user_id, false, provider, model, 4000, 0, now) .await .unwrap(); diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index b420960122..16243e1ff2 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -666,6 +666,7 @@ impl TestServer { openai_api_key: None, google_ai_api_key: None, anthropic_api_key: None, + anthropic_staff_api_key: None, clickhouse_url: None, clickhouse_user: None, clickhouse_password: None,