collab: Attach GeoIP country code to RPC sessions (#15814)

This PR updates collab to attach the user's GeoIP country code to their
RPC session.

We source the country code from the
[`CF-IPCountry`](https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry)
header.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-05 11:11:49 -04:00 committed by GitHub
parent be0ccf47ee
commit f11f3f2599
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 52 additions and 31 deletions

View file

@ -1,5 +1,6 @@
mod connection_pool;
use crate::api::CloudflareIpCountryHeader;
use crate::{
auth,
db::{
@ -152,6 +153,9 @@ struct Session {
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<IsahcHttpClient>,
rate_limiter: Arc<RateLimiter>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
_executor: Executor,
}
@ -984,6 +988,7 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
geoip_country_code: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
@ -1009,7 +1014,10 @@ impl Server {
let executor = executor.clone();
move |duration| executor.sleep(duration)
});
tracing::Span::current().record("connection_id", format!("{}", connection_id));
tracing::Span::current()
.record("connection_id", format!("{}", connection_id))
.record("geoip_country_code", geoip_country_code.clone());
tracing::info!("connection opened");
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
@ -1039,6 +1047,7 @@ impl Server {
live_kit_client: this.app_state.live_kit_client.clone(),
http_client,
rate_limiter: this.app_state.rate_limiter.clone(),
geoip_country_code,
_executor: executor.clone(),
supermaven_client,
};
@ -1395,6 +1404,7 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
@ -1435,6 +1445,7 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
country_code_header.map(|header| header.to_string()),
None,
Executor::Production,
)