Send llm events to snowflake too (#21091)

Closes #ISSUE

Release Notes:

- N/A
This commit is contained in:
Conrad Irwin 2024-11-22 20:40:39 -07:00 committed by GitHub
parent 5766afe710
commit 984bb192ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 183 additions and 25 deletions

View file

@ -1067,6 +1067,8 @@ impl Client {
let proxy = http.proxy().cloned(); let proxy = http.proxy().cloned();
let credentials = credentials.clone(); let credentials = credentials.clone();
let rpc_url = self.rpc_url(http, release_channel); let rpc_url = self.rpc_url(http, release_channel);
let system_id = self.telemetry.system_id();
let metrics_id = self.telemetry.metrics_id();
cx.background_executor().spawn(async move { cx.background_executor().spawn(async move {
use HttpOrHttps::*; use HttpOrHttps::*;
@ -1118,6 +1120,12 @@ impl Client {
"x-zed-release-channel", "x-zed-release-channel",
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?, HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
); );
if let Some(system_id) = system_id {
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
}
if let Some(metrics_id) = metrics_id {
request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
}
match url_scheme { match url_scheme {
Https => { Https => {

View file

@ -533,6 +533,10 @@ impl Telemetry {
self.state.lock().metrics_id.clone() self.state.lock().metrics_id.clone()
} }
/// Returns a clone of the `system_id` currently stored in the telemetry
/// state, or `None` when no system id has been recorded.
pub fn system_id(self: &Arc<Self>) -> Option<Arc<str>> {
    let state = self.state.lock();
    state.system_id.clone()
}
pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> { pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().installation_id.clone() self.state.lock().installation_id.clone()
} }

View file

@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
} }
} }
/// Typed header wrapping the string value of the `x-zed-system-id` request header.
pub struct SystemIdHeader(String);
impl Header for SystemIdHeader {
    /// Returns the header name (`x-zed-system-id`), constructed once and cached.
    fn name() -> &'static HeaderName {
        static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
        SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
    }

    /// Decodes the header from the first value yielded by `values`.
    ///
    /// # Errors
    ///
    /// Returns `axum::headers::Error::invalid()` when `values` is empty or
    /// when the first value cannot be converted with `HeaderValue::to_str`.
    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i axum::http::HeaderValue>,
    {
        let system_id = values
            .next()
            .ok_or_else(axum::headers::Error::invalid)?
            .to_str()
            .map_err(|_| axum::headers::Error::invalid())?;
        Ok(Self(system_id.to_string()))
    }

    /// Appends the stored system id to `values` as a single header value.
    ///
    /// Previously this panicked via `unimplemented!()`. If the stored string
    /// is not a valid header value, nothing is appended instead of panicking.
    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
        if let Ok(value) = axum::http::HeaderValue::from_str(&self.0) {
            values.extend(std::iter::once(value));
        }
    }
}
impl std::fmt::Display for SystemIdHeader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> { pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new() Router::new()
.route("/user", get(get_authenticated_user)) .route("/user", get(get_authenticated_user))

View file

@ -1578,8 +1578,8 @@ fn for_snowflake(
}) })
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Debug)]
struct SnowflakeRow { pub struct SnowflakeRow {
pub time: chrono::DateTime<chrono::Utc>, pub time: chrono::DateTime<chrono::Utc>,
pub user_id: Option<String>, pub user_id: Option<String>,
pub device_id: Option<String>, pub device_id: Option<String>,
@ -1588,3 +1588,42 @@ struct SnowflakeRow {
pub user_properties: Option<serde_json::Value>, pub user_properties: Option<serde_json::Value>,
pub insert_id: Option<String>, pub insert_id: Option<String>,
} }
impl SnowflakeRow {
    /// Builds a row for `event_type`, timestamped with the current UTC time.
    ///
    /// `metrics_id` is stored (stringified) as `user_id`, `system_id` is
    /// stored as `device_id`, `is_staff` is recorded under `user_properties`,
    /// and a freshly generated UUID is used as `insert_id`.
    pub fn new(
        event_type: impl Into<String>,
        metrics_id: Option<Uuid>,
        is_staff: bool,
        system_id: Option<String>,
        event_properties: serde_json::Value,
    ) -> Self {
        Self {
            time: chrono::Utc::now(),
            event_type: event_type.into(),
            device_id: system_id,
            user_id: metrics_id.map(|id| id.to_string()),
            insert_id: Some(uuid::Uuid::new_v4().to_string()),
            event_properties,
            user_properties: Some(json!({"is_staff": is_staff})),
        }
    }

    /// Serializes the row as JSON and puts it on the given Kinesis stream.
    ///
    /// Does nothing and returns `Ok(())` when either `client` or `stream` is
    /// `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails or if the `put_record`
    /// request fails.
    pub async fn write(
        self,
        client: &Option<aws_sdk_kinesis::Client>,
        stream: &Option<String>,
    ) -> anyhow::Result<()> {
        // Kinesis documents partition keys as 1 to 256 characters long.
        const MAX_PARTITION_KEY_CHARS: usize = 256;

        let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else {
            return Ok(());
        };
        let row = serde_json::to_vec(&self)?;

        // A missing `user_id` used to produce an empty partition key, which
        // Kinesis rejects. Fall back to the other identifiers on the row, and
        // finally to a random key, so the record is always accepted.
        let partition_key = [self.user_id, self.device_id, self.insert_id]
            .into_iter()
            .flatten()
            .find(|key| !key.is_empty() && key.chars().count() <= MAX_PARTITION_KEY_CHARS)
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        client
            .put_record()
            .stream_name(stream)
            .partition_key(partition_key)
            .data(row.into())
            .send()
            .await?;
        Ok(())
    }
}

View file

@ -1,3 +1,5 @@
use serde::Serialize;
/// A number of cents. /// A number of cents.
#[derive( #[derive(
Debug, Debug,
@ -12,6 +14,7 @@
derive_more::AddAssign, derive_more::AddAssign,
derive_more::Sub, derive_more::Sub,
derive_more::SubAssign, derive_more::SubAssign,
Serialize,
)] )]
pub struct Cents(pub u32); pub struct Cents(pub u32);

View file

@ -3,9 +3,11 @@ pub mod db;
mod telemetry; mod telemetry;
mod token; mod token;
use crate::api::events::SnowflakeRow;
use crate::api::CloudflareIpCountryHeader;
use crate::build_kinesis_client;
use crate::{ use crate::{
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents, build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
Config, Error, Result,
}; };
use anyhow::{anyhow, Context as _}; use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model; use authorization::authorize_access_to_language_model;
@ -28,6 +30,7 @@ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
}; };
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use serde_json::json;
use std::{ use std::{
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
@ -45,6 +48,7 @@ pub struct LlmState {
pub executor: Executor, pub executor: Executor,
pub db: Arc<LlmDatabase>, pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient, pub http_client: ReqwestClient,
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
pub clickhouse_client: Option<clickhouse::Client>, pub clickhouse_client: Option<clickhouse::Client>,
active_user_count_by_model: active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>, RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
@ -77,6 +81,11 @@ impl LlmState {
executor, executor,
db, db,
http_client, http_client,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
} else {
None
},
clickhouse_client: config clickhouse_client: config
.clickhouse_url .clickhouse_url
.as_ref() .as_ref()
@ -521,25 +530,50 @@ async fn check_usage_limit(
UsageMeasure::TokensPerDay => "tokens_per_day", UsageMeasure::TokensPerDay => "tokens_per_day",
}; };
if let Some(client) = state.clickhouse_client.as_ref() { tracing::info!(
tracing::info!( target: "user rate limit",
target: "user rate limit", user_id = claims.user_id,
user_id = claims.user_id, login = claims.github_user_login,
login = claims.github_user_login, authn.jti = claims.jti,
authn.jti = claims.jti, is_staff = claims.is_staff,
is_staff = claims.is_staff, provider = provider.to_string(),
provider = provider.to_string(), model = model.name,
model = model.name, requests_this_minute = usage.requests_this_minute,
requests_this_minute = usage.requests_this_minute, tokens_this_minute = usage.tokens_this_minute,
tokens_this_minute = usage.tokens_this_minute, tokens_this_day = usage.tokens_this_day,
tokens_this_day = usage.tokens_this_day, users_in_recent_minutes = users_in_recent_minutes,
users_in_recent_minutes = users_in_recent_minutes, users_in_recent_days = users_in_recent_days,
users_in_recent_days = users_in_recent_days, max_requests_per_minute = per_user_max_requests_per_minute,
max_requests_per_minute = per_user_max_requests_per_minute, max_tokens_per_minute = per_user_max_tokens_per_minute,
max_tokens_per_minute = per_user_max_tokens_per_minute, max_tokens_per_day = per_user_max_tokens_per_day,
max_tokens_per_day = per_user_max_tokens_per_day, );
);
SnowflakeRow::new(
"Language Model Rate Limited",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"usage": usage,
"users_in_recent_minutes": users_in_recent_minutes,
"users_in_recent_days": users_in_recent_days,
"max_requests_per_minute": per_user_max_requests_per_minute,
"max_tokens_per_minute": per_user_max_tokens_per_minute,
"max_tokens_per_day": per_user_max_tokens_per_day,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model.name.clone(),
"provider": provider.to_string(),
"usage_measure": resource.to_string(),
}),
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
if let Some(client) = state.clickhouse_client.as_ref() {
report_llm_rate_limit( report_llm_rate_limit(
client, client,
LlmRateLimitEventRow { LlmRateLimitEventRow {
@ -652,6 +686,27 @@ impl<S> Drop for TokenCountingStream<S> {
tokens_this_minute = usage.tokens_this_minute, tokens_this_minute = usage.tokens_this_minute,
); );
let properties = json!({
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model,
"provider": provider,
"usage": usage,
"tokens": tokens
});
SnowflakeRow::new(
"Language Model Used",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
properties,
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();
if let Some(clickhouse_client) = state.clickhouse_client.as_ref() { if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
report_llm_usage( report_llm_usage(
clickhouse_client, clickhouse_client,

View file

@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _;
use super::*; use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default)] #[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage { pub struct TokenUsage {
pub input: usize, pub input: usize,
pub input_cache_creation: usize, pub input_cache_creation: usize,
@ -23,7 +23,7 @@ impl TokenUsage {
} }
} }
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage { pub struct Usage {
pub requests_this_minute: usize, pub requests_this_minute: usize,
pub tokens_this_minute: usize, pub tokens_this_minute: usize,

View file

@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::Duration; use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)] #[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -16,6 +17,10 @@ pub struct LlmTokenClaims {
pub exp: u64, pub exp: u64,
pub jti: String, pub jti: String,
pub user_id: u64, pub user_id: u64,
#[serde(default)]
pub system_id: Option<String>,
#[serde(default)]
pub metrics_id: Option<Uuid>,
pub github_user_login: String, pub github_user_login: String,
pub is_staff: bool, pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool, pub has_llm_closed_beta_feature_flag: bool,
@ -36,6 +41,7 @@ impl LlmTokenClaims {
has_llm_closed_beta_feature_flag: bool, has_llm_closed_beta_feature_flag: bool,
has_llm_subscription: bool, has_llm_subscription: bool,
plan: rpc::proto::Plan, plan: rpc::proto::Plan,
system_id: Option<String>,
config: &Config, config: &Config,
) -> Result<String> { ) -> Result<String> {
let secret = config let secret = config
@ -49,6 +55,8 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(), jti: uuid::Uuid::new_v4().to_string(),
user_id: user.id.to_proto(), user_id: user.id.to_proto(),
system_id,
metrics_id: Some(user.metrics_id),
github_user_login: user.github_login.clone(), github_user_login: user.github_login.clone(),
is_staff, is_staff,
has_llm_closed_beta_feature_flag, has_llm_closed_beta_feature_flag,

View file

@ -1,6 +1,6 @@
mod connection_pool; mod connection_pool;
use crate::api::CloudflareIpCountryHeader; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims; use crate::llm::LlmTokenClaims;
use crate::{ use crate::{
auth, auth,
@ -137,6 +137,7 @@ struct Session {
/// The GeoIP country code for the user. /// The GeoIP country code for the user.
#[allow(unused)] #[allow(unused)]
geoip_country_code: Option<String>, geoip_country_code: Option<String>,
system_id: Option<String>,
_executor: Executor, _executor: Executor,
} }
@ -682,6 +683,7 @@ impl Server {
principal: Principal, principal: Principal,
zed_version: ZedVersion, zed_version: ZedVersion,
geoip_country_code: Option<String>, geoip_country_code: Option<String>,
system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>, send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor, executor: Executor,
) -> impl Future<Output = ()> { ) -> impl Future<Output = ()> {
@ -737,6 +739,7 @@ impl Server {
app_state: this.app_state.clone(), app_state: this.app_state.clone(),
http_client, http_client,
geoip_country_code, geoip_country_code,
system_id,
_executor: executor.clone(), _executor: executor.clone(),
supermaven_client, supermaven_client,
}; };
@ -1056,6 +1059,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
.layer(Extension(server)) .layer(Extension(server))
} }
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request( pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>, TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>, app_version_header: Option<TypedHeader<AppVersionHeader>>,
@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request(
Extension(server): Extension<Arc<Server>>, Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>, Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>, country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
system_id_header: Option<TypedHeader<SystemIdHeader>>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> axum::response::Response { ) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION { if protocol_version != rpc::PROTOCOL_VERSION {
@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
principal, principal,
version, version,
country_code_header.map(|header| header.to_string()), country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
None, None,
Executor::Production, Executor::Production,
) )
@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
has_llm_closed_beta_feature_flag, has_llm_closed_beta_feature_flag,
has_llm_subscription, has_llm_subscription,
session.current_plan(&db).await?, session.current_plan(&db).await?,
session.system_id.clone(),
&session.app_state.config, &session.app_state.config,
)?; )?;
response.send(proto::GetLlmTokenResponse { token })?; response.send(proto::GetLlmTokenResponse { token })?;

View file

@ -244,6 +244,7 @@ impl TestServer {
Principal::User(user), Principal::User(user),
ZedVersion(SemanticVersion::new(1, 0, 0)), ZedVersion(SemanticVersion::new(1, 0, 0)),
None, None,
None,
Some(connection_id_tx), Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()), Executor::Deterministic(cx.background_executor().clone()),
)) ))