Send llm events to snowflake too (#21091)

Closes #ISSUE

Release Notes:

- N/A
This commit is contained in:
Conrad Irwin 2024-11-22 20:40:39 -07:00 committed by GitHub
parent 5766afe710
commit 984bb192ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 183 additions and 25 deletions

View file

@ -1,6 +1,6 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
auth,
@ -137,6 +137,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
system_id: Option<String>,
_executor: Executor,
}
@ -682,6 +683,7 @@ impl Server {
principal: Principal,
zed_version: ZedVersion,
geoip_country_code: Option<String>,
system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
@ -737,6 +739,7 @@ impl Server {
app_state: this.app_state.clone(),
http_client,
geoip_country_code,
system_id,
_executor: executor.clone(),
supermaven_client,
};
@ -1056,6 +1059,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
.layer(Extension(server))
}
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>,
@ -1063,6 +1067,7 @@ pub async fn handle_websocket_request(
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
system_id_header: Option<TypedHeader<SystemIdHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
principal,
version,
country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
None,
Executor::Production,
)
@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
has_llm_closed_beta_feature_flag,
has_llm_subscription,
session.current_plan(&db).await?,
session.system_id.clone(),
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;