Send llm events to snowflake too (#21091)
Closes #ISSUE Release Notes: - N/A
This commit is contained in:
parent
5766afe710
commit
984bb192ba
10 changed files with 183 additions and 25 deletions
|
@ -1067,6 +1067,8 @@ impl Client {
|
|||
let proxy = http.proxy().cloned();
|
||||
let credentials = credentials.clone();
|
||||
let rpc_url = self.rpc_url(http, release_channel);
|
||||
let system_id = self.telemetry.system_id();
|
||||
let metrics_id = self.telemetry.metrics_id();
|
||||
cx.background_executor().spawn(async move {
|
||||
use HttpOrHttps::*;
|
||||
|
||||
|
@ -1118,6 +1120,12 @@ impl Client {
|
|||
"x-zed-release-channel",
|
||||
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
|
||||
);
|
||||
if let Some(system_id) = system_id {
|
||||
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
|
||||
}
|
||||
if let Some(metrics_id) = metrics_id {
|
||||
request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
|
||||
}
|
||||
|
||||
match url_scheme {
|
||||
Https => {
|
||||
|
|
|
@ -533,6 +533,10 @@ impl Telemetry {
|
|||
self.state.lock().metrics_id.clone()
|
||||
}
|
||||
|
||||
pub fn system_id(self: &Arc<Self>) -> Option<Arc<str>> {
|
||||
self.state.lock().system_id.clone()
|
||||
}
|
||||
|
||||
pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
|
||||
self.state.lock().installation_id.clone()
|
||||
}
|
||||
|
|
|
@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct SystemIdHeader(String);
|
||||
|
||||
impl Header for SystemIdHeader {
|
||||
fn name() -> &'static HeaderName {
|
||||
static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||
SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
|
||||
where
|
||||
Self: Sized,
|
||||
I: Iterator<Item = &'i axum::http::HeaderValue>,
|
||||
{
|
||||
let system_id = values
|
||||
.next()
|
||||
.ok_or_else(axum::headers::Error::invalid)?
|
||||
.to_str()
|
||||
.map_err(|_| axum::headers::Error::invalid())?;
|
||||
|
||||
Ok(Self(system_id.to_string()))
|
||||
}
|
||||
|
||||
fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SystemIdHeader {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/user", get(get_authenticated_user))
|
||||
|
|
|
@ -1578,8 +1578,8 @@ fn for_snowflake(
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SnowflakeRow {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SnowflakeRow {
|
||||
pub time: chrono::DateTime<chrono::Utc>,
|
||||
pub user_id: Option<String>,
|
||||
pub device_id: Option<String>,
|
||||
|
@ -1588,3 +1588,42 @@ struct SnowflakeRow {
|
|||
pub user_properties: Option<serde_json::Value>,
|
||||
pub insert_id: Option<String>,
|
||||
}
|
||||
|
||||
impl SnowflakeRow {
|
||||
pub fn new(
|
||||
event_type: impl Into<String>,
|
||||
metrics_id: Option<Uuid>,
|
||||
is_staff: bool,
|
||||
system_id: Option<String>,
|
||||
event_properties: serde_json::Value,
|
||||
) -> Self {
|
||||
Self {
|
||||
time: chrono::Utc::now(),
|
||||
event_type: event_type.into(),
|
||||
device_id: system_id,
|
||||
user_id: metrics_id.map(|id| id.to_string()),
|
||||
insert_id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
event_properties,
|
||||
user_properties: Some(json!({"is_staff": is_staff})),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write(
|
||||
self,
|
||||
client: &Option<aws_sdk_kinesis::Client>,
|
||||
stream: &Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else {
|
||||
return Ok(());
|
||||
};
|
||||
let row = serde_json::to_vec(&self)?;
|
||||
client
|
||||
.put_record()
|
||||
.stream_name(stream)
|
||||
.partition_key(&self.user_id.unwrap_or_default())
|
||||
.data(row.into())
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use serde::Serialize;
|
||||
|
||||
/// A number of cents.
|
||||
#[derive(
|
||||
Debug,
|
||||
|
@ -12,6 +14,7 @@
|
|||
derive_more::AddAssign,
|
||||
derive_more::Sub,
|
||||
derive_more::SubAssign,
|
||||
Serialize,
|
||||
)]
|
||||
pub struct Cents(pub u32);
|
||||
|
||||
|
|
|
@ -3,9 +3,11 @@ pub mod db;
|
|||
mod telemetry;
|
||||
mod token;
|
||||
|
||||
use crate::api::events::SnowflakeRow;
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::build_kinesis_client;
|
||||
use crate::{
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
|
||||
Config, Error, Result,
|
||||
build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use authorization::authorize_access_to_language_model;
|
||||
|
@ -28,6 +30,7 @@ use rpc::{
|
|||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
|
@ -45,6 +48,7 @@ pub struct LlmState {
|
|||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub http_client: ReqwestClient,
|
||||
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
|
||||
pub clickhouse_client: Option<clickhouse::Client>,
|
||||
active_user_count_by_model:
|
||||
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
|
||||
|
@ -77,6 +81,11 @@ impl LlmState {
|
|||
executor,
|
||||
db,
|
||||
http_client,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
clickhouse_client: config
|
||||
.clickhouse_url
|
||||
.as_ref()
|
||||
|
@ -521,25 +530,50 @@ async fn check_usage_limit(
|
|||
UsageMeasure::TokensPerDay => "tokens_per_day",
|
||||
};
|
||||
|
||||
if let Some(client) = state.clickhouse_client.as_ref() {
|
||||
tracing::info!(
|
||||
target: "user rate limit",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model.name,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
tokens_this_day = usage.tokens_this_day,
|
||||
users_in_recent_minutes = users_in_recent_minutes,
|
||||
users_in_recent_days = users_in_recent_days,
|
||||
max_requests_per_minute = per_user_max_requests_per_minute,
|
||||
max_tokens_per_minute = per_user_max_tokens_per_minute,
|
||||
max_tokens_per_day = per_user_max_tokens_per_day,
|
||||
);
|
||||
tracing::info!(
|
||||
target: "user rate limit",
|
||||
user_id = claims.user_id,
|
||||
login = claims.github_user_login,
|
||||
authn.jti = claims.jti,
|
||||
is_staff = claims.is_staff,
|
||||
provider = provider.to_string(),
|
||||
model = model.name,
|
||||
requests_this_minute = usage.requests_this_minute,
|
||||
tokens_this_minute = usage.tokens_this_minute,
|
||||
tokens_this_day = usage.tokens_this_day,
|
||||
users_in_recent_minutes = users_in_recent_minutes,
|
||||
users_in_recent_days = users_in_recent_days,
|
||||
max_requests_per_minute = per_user_max_requests_per_minute,
|
||||
max_tokens_per_minute = per_user_max_tokens_per_minute,
|
||||
max_tokens_per_day = per_user_max_tokens_per_day,
|
||||
);
|
||||
|
||||
SnowflakeRow::new(
|
||||
"Language Model Rate Limited",
|
||||
claims.metrics_id,
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
json!({
|
||||
"usage": usage,
|
||||
"users_in_recent_minutes": users_in_recent_minutes,
|
||||
"users_in_recent_days": users_in_recent_days,
|
||||
"max_requests_per_minute": per_user_max_requests_per_minute,
|
||||
"max_tokens_per_minute": per_user_max_tokens_per_minute,
|
||||
"max_tokens_per_day": per_user_max_tokens_per_day,
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model.name.clone(),
|
||||
"provider": provider.to_string(),
|
||||
"usage_measure": resource.to_string(),
|
||||
}),
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if let Some(client) = state.clickhouse_client.as_ref() {
|
||||
report_llm_rate_limit(
|
||||
client,
|
||||
LlmRateLimitEventRow {
|
||||
|
@ -652,6 +686,27 @@ impl<S> Drop for TokenCountingStream<S> {
|
|||
tokens_this_minute = usage.tokens_this_minute,
|
||||
);
|
||||
|
||||
let properties = json!({
|
||||
"plan": match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"usage": usage,
|
||||
"tokens": tokens
|
||||
});
|
||||
SnowflakeRow::new(
|
||||
"Language Model Used",
|
||||
claims.metrics_id,
|
||||
claims.is_staff,
|
||||
claims.system_id.clone(),
|
||||
properties,
|
||||
)
|
||||
.write(&state.kinesis_client, &state.config.kinesis_stream)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
|
||||
report_llm_usage(
|
||||
clickhouse_client,
|
||||
|
|
|
@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _;
|
|||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default)]
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
|
@ -23,7 +23,7 @@ impl TokenUsage {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
|
|
|
@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
@ -16,6 +17,10 @@ pub struct LlmTokenClaims {
|
|||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
#[serde(default)]
|
||||
pub system_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub metrics_id: Option<Uuid>,
|
||||
pub github_user_login: String,
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
|
@ -36,6 +41,7 @@ impl LlmTokenClaims {
|
|||
has_llm_closed_beta_feature_flag: bool,
|
||||
has_llm_subscription: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
let secret = config
|
||||
|
@ -49,6 +55,8 @@ impl LlmTokenClaims {
|
|||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user.id.to_proto(),
|
||||
system_id,
|
||||
metrics_id: Some(user.metrics_id),
|
||||
github_user_login: user.github_login.clone(),
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
mod connection_pool;
|
||||
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{
|
||||
auth,
|
||||
|
@ -137,6 +137,7 @@ struct Session {
|
|||
/// The GeoIP country code for the user.
|
||||
#[allow(unused)]
|
||||
geoip_country_code: Option<String>,
|
||||
system_id: Option<String>,
|
||||
_executor: Executor,
|
||||
}
|
||||
|
||||
|
@ -682,6 +683,7 @@ impl Server {
|
|||
principal: Principal,
|
||||
zed_version: ZedVersion,
|
||||
geoip_country_code: Option<String>,
|
||||
system_id: Option<String>,
|
||||
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
|
||||
executor: Executor,
|
||||
) -> impl Future<Output = ()> {
|
||||
|
@ -737,6 +739,7 @@ impl Server {
|
|||
app_state: this.app_state.clone(),
|
||||
http_client,
|
||||
geoip_country_code,
|
||||
system_id,
|
||||
_executor: executor.clone(),
|
||||
supermaven_client,
|
||||
};
|
||||
|
@ -1056,6 +1059,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
|
|||
.layer(Extension(server))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn handle_websocket_request(
|
||||
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
|
||||
app_version_header: Option<TypedHeader<AppVersionHeader>>,
|
||||
|
@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request(
|
|||
Extension(server): Extension<Arc<Server>>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
system_id_header: Option<TypedHeader<SystemIdHeader>>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> axum::response::Response {
|
||||
if protocol_version != rpc::PROTOCOL_VERSION {
|
||||
|
@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
|
|||
principal,
|
||||
version,
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
system_id_header.map(|header| header.to_string()),
|
||||
None,
|
||||
Executor::Production,
|
||||
)
|
||||
|
@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
|
|||
has_llm_closed_beta_feature_flag,
|
||||
has_llm_subscription,
|
||||
session.current_plan(&db).await?,
|
||||
session.system_id.clone(),
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
response.send(proto::GetLlmTokenResponse { token })?;
|
||||
|
|
|
@ -244,6 +244,7 @@ impl TestServer {
|
|||
Principal::User(user),
|
||||
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
||||
None,
|
||||
None,
|
||||
Some(connection_id_tx),
|
||||
Executor::Deterministic(cx.background_executor().clone()),
|
||||
))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue