collab: Attach GeoIP country code to RPC sessions (#15814)
This PR updates collab to attach the user's GeoIP country code to their RPC session. We source the country code from the [`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry) header. Release Notes: - N/A
This commit is contained in:
parent
be0ccf47ee
commit
f11f3f2599
4 changed files with 52 additions and 31 deletions
|
@ -14,7 +14,8 @@ use anyhow::anyhow;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::{Path, Query},
|
extract::{Path, Query},
|
||||||
http::{self, Request, StatusCode},
|
headers::Header,
|
||||||
|
http::{self, HeaderName, Request, StatusCode},
|
||||||
middleware::{self, Next},
|
middleware::{self, Next},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
|
@ -22,11 +23,44 @@ use axum::{
|
||||||
};
|
};
|
||||||
use axum_extra::response::ErasedJson;
|
use axum_extra::response::ErasedJson;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, OnceLock};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
|
|
||||||
pub use extensions::fetch_extensions_from_blob_store_periodically;
|
pub use extensions::fetch_extensions_from_blob_store_periodically;
|
||||||
|
|
||||||
|
pub struct CloudflareIpCountryHeader(String);
|
||||||
|
|
||||||
|
impl Header for CloudflareIpCountryHeader {
|
||||||
|
fn name() -> &'static HeaderName {
|
||||||
|
static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
||||||
|
CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
I: Iterator<Item = &'i axum::http::HeaderValue>,
|
||||||
|
{
|
||||||
|
let country_code = values
|
||||||
|
.next()
|
||||||
|
.ok_or_else(axum::headers::Error::invalid)?
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| axum::headers::Error::invalid())?;
|
||||||
|
|
||||||
|
Ok(Self(country_code.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for CloudflareIpCountryHeader {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
|
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/user", get(get_authenticated_user))
|
.route("/user", get(get_authenticated_user))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use super::ips_file::IpsFile;
|
use super::ips_file::IpsFile;
|
||||||
|
use crate::api::CloudflareIpCountryHeader;
|
||||||
use crate::{api::slack, AppState, Error, Result};
|
use crate::{api::slack, AppState, Error, Result};
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use aws_sdk_s3::primitives::ByteStream;
|
use aws_sdk_s3::primitives::ByteStream;
|
||||||
|
@ -59,33 +60,6 @@ impl Header for ZedChecksumHeader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CloudflareIpCountryHeader(String);
|
|
||||||
|
|
||||||
impl Header for CloudflareIpCountryHeader {
|
|
||||||
fn name() -> &'static HeaderName {
|
|
||||||
static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
|
|
||||||
CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
I: Iterator<Item = &'i axum::http::HeaderValue>,
|
|
||||||
{
|
|
||||||
let country_code = values
|
|
||||||
.next()
|
|
||||||
.ok_or_else(axum::headers::Error::invalid)?
|
|
||||||
.to_str()
|
|
||||||
.map_err(|_| axum::headers::Error::invalid())?;
|
|
||||||
|
|
||||||
Ok(Self(country_code.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn post_crash(
|
pub async fn post_crash(
|
||||||
Extension(app): Extension<Arc<AppState>>,
|
Extension(app): Extension<Arc<AppState>>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
|
@ -413,7 +387,7 @@ pub async fn post_events(
|
||||||
let Some(last_event) = request_body.events.last() else {
|
let Some(last_event) = request_body.events.last() else {
|
||||||
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
return Err(Error::Http(StatusCode::BAD_REQUEST, "no events".into()))?;
|
||||||
};
|
};
|
||||||
let country_code = country_code_header.map(|h| h.0 .0);
|
let country_code = country_code_header.map(|h| h.to_string());
|
||||||
|
|
||||||
let first_event_at = chrono::Utc::now()
|
let first_event_at = chrono::Utc::now()
|
||||||
- chrono::Duration::milliseconds(last_event.milliseconds_since_first_event);
|
- chrono::Duration::milliseconds(last_event.milliseconds_since_first_event);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod connection_pool;
|
mod connection_pool;
|
||||||
|
|
||||||
|
use crate::api::CloudflareIpCountryHeader;
|
||||||
use crate::{
|
use crate::{
|
||||||
auth,
|
auth,
|
||||||
db::{
|
db::{
|
||||||
|
@ -152,6 +153,9 @@ struct Session {
|
||||||
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
supermaven_client: Option<Arc<SupermavenAdminApi>>,
|
||||||
http_client: Arc<IsahcHttpClient>,
|
http_client: Arc<IsahcHttpClient>,
|
||||||
rate_limiter: Arc<RateLimiter>,
|
rate_limiter: Arc<RateLimiter>,
|
||||||
|
/// The GeoIP country code for the user.
|
||||||
|
#[allow(unused)]
|
||||||
|
geoip_country_code: Option<String>,
|
||||||
_executor: Executor,
|
_executor: Executor,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -984,6 +988,7 @@ impl Server {
|
||||||
address: String,
|
address: String,
|
||||||
principal: Principal,
|
principal: Principal,
|
||||||
zed_version: ZedVersion,
|
zed_version: ZedVersion,
|
||||||
|
geoip_country_code: Option<String>,
|
||||||
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
|
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
|
||||||
executor: Executor,
|
executor: Executor,
|
||||||
) -> impl Future<Output = ()> {
|
) -> impl Future<Output = ()> {
|
||||||
|
@ -1009,7 +1014,10 @@ impl Server {
|
||||||
let executor = executor.clone();
|
let executor = executor.clone();
|
||||||
move |duration| executor.sleep(duration)
|
move |duration| executor.sleep(duration)
|
||||||
});
|
});
|
||||||
tracing::Span::current().record("connection_id", format!("{}", connection_id));
|
tracing::Span::current()
|
||||||
|
.record("connection_id", format!("{}", connection_id))
|
||||||
|
.record("geoip_country_code", geoip_country_code.clone());
|
||||||
|
|
||||||
tracing::info!("connection opened");
|
tracing::info!("connection opened");
|
||||||
|
|
||||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||||
|
@ -1039,6 +1047,7 @@ impl Server {
|
||||||
live_kit_client: this.app_state.live_kit_client.clone(),
|
live_kit_client: this.app_state.live_kit_client.clone(),
|
||||||
http_client,
|
http_client,
|
||||||
rate_limiter: this.app_state.rate_limiter.clone(),
|
rate_limiter: this.app_state.rate_limiter.clone(),
|
||||||
|
geoip_country_code,
|
||||||
_executor: executor.clone(),
|
_executor: executor.clone(),
|
||||||
supermaven_client,
|
supermaven_client,
|
||||||
};
|
};
|
||||||
|
@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request(
|
||||||
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
|
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
|
||||||
Extension(server): Extension<Arc<Server>>,
|
Extension(server): Extension<Arc<Server>>,
|
||||||
Extension(principal): Extension<Principal>,
|
Extension(principal): Extension<Principal>,
|
||||||
|
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
if protocol_version != rpc::PROTOCOL_VERSION {
|
if protocol_version != rpc::PROTOCOL_VERSION {
|
||||||
|
@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request(
|
||||||
socket_address,
|
socket_address,
|
||||||
principal,
|
principal,
|
||||||
version,
|
version,
|
||||||
|
country_code_header.map(|header| header.to_string()),
|
||||||
None,
|
None,
|
||||||
Executor::Production,
|
Executor::Production,
|
||||||
)
|
)
|
||||||
|
|
|
@ -244,6 +244,7 @@ impl TestServer {
|
||||||
client_name,
|
client_name,
|
||||||
Principal::User(user),
|
Principal::User(user),
|
||||||
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
||||||
|
None,
|
||||||
Some(connection_id_tx),
|
Some(connection_id_tx),
|
||||||
Executor::Deterministic(cx.background_executor().clone()),
|
Executor::Deterministic(cx.background_executor().clone()),
|
||||||
))
|
))
|
||||||
|
@ -377,6 +378,7 @@ impl TestServer {
|
||||||
"dev-server".to_string(),
|
"dev-server".to_string(),
|
||||||
Principal::DevServer(dev_server),
|
Principal::DevServer(dev_server),
|
||||||
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
||||||
|
None,
|
||||||
Some(connection_id_tx),
|
Some(connection_id_tx),
|
||||||
Executor::Deterministic(cx.background_executor().clone()),
|
Executor::Deterministic(cx.background_executor().clone()),
|
||||||
))
|
))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue