Store the impersonator id on access tokens created via ZED_IMPERSONATE

* Use the impersonator id to prevent these tokens from counting
  against the impersonated user when limiting the users' total
  of access tokens.
* When connecting using an access token with an impersonator
  add the impersonator as a field to the tracing span that wraps
  the task for that connection.
* Disallow impersonating users via the admin API token in production,
  because when using the admin API token, we aren't able to identify
  the impersonator.

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-01-17 15:46:36 -08:00
parent 9521f49160
commit ab1bea515c
9 changed files with 198 additions and 39 deletions

View file

@ -1,7 +1,7 @@
mod connection_pool;
use crate::{
auth,
auth::{self, Impersonator},
db::{
self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@ -65,7 +65,7 @@ use std::{
use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument};
use tracing::{field, info_span, instrument, Instrument};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@ -561,13 +561,17 @@ impl Server {
connection: Connection,
address: String,
user: User,
impersonator: Option<User>,
mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
let user_id = user.id;
let login = user.github_login;
let span = info_span!("handle connection", %user_id, %login, %address);
let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
if let Some(impersonator) = impersonator {
span.record("impersonator", &impersonator.github_login);
}
let mut teardown = self.teardown.subscribe();
async move {
let (connection_id, handle_io, mut incoming_rx) = this
@ -839,6 +843,7 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(user): Extension<User>,
Extension(impersonator): Extension<Impersonator>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
@ -858,7 +863,14 @@ pub async fn handle_websocket_request(
let connection = Connection::new(Box::pin(socket));
async move {
server
.handle_connection(connection, socket_address, user, None, Executor::Production)
.handle_connection(
connection,
socket_address,
user,
impersonator.0,
None,
Executor::Production,
)
.await
.log_err();
}