From cb4f86881556e817067b4c8c839c241a26e078c2 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Fri, 22 Mar 2024 08:44:56 -0600 Subject: [PATCH] remoting (#9680) This PR provides some of the plumbing needed for a "remote" zed instance. The way this will work is: * From zed on your laptop you'll be able to manage a set of dev servers, each of which is identified by a token. * You'll run `zed --dev-server-token XXXX` to boot a remotable dev server. * From the zed on your laptop you'll be able to open directories and work on the projects on the remote server (exactly like collaboration works today). For now all this PR does is provide the ability for a zed instance to sign in using a "dev server token". The next steps will be: * Adding support to the collaboration protocol to instruct a dev server to "open" a directory and share it into a channel. * Adding UI to manage these servers and tokens (manually for now) Related #5347 Release Notes: - N/A --------- Co-authored-by: Nathan --- Cargo.lock | 1 + crates/client/src/client.rs | 73 ++- crates/client/src/test.rs | 11 +- crates/collab/README.md | 2 +- .../20221109000000_test_schema.sql | 8 + .../20240321162658_add_devservers.sql | 7 + crates/collab/src/auth.rs | 74 ++- crates/collab/src/db/ids.rs | 23 +- crates/collab/src/db/queries.rs | 1 + crates/collab/src/db/queries/dev_servers.rs | 18 + crates/collab/src/db/tables.rs | 1 + crates/collab/src/db/tables/dev_server.rs | 17 + crates/collab/src/rpc.rs | 532 +++++++++++------- crates/collab/src/tests/test_server.rs | 16 +- crates/theme/theme.md | 16 +- crates/zed/Cargo.toml | 1 + crates/zed/src/main.rs | 72 ++- script/create-migration | 3 + script/eula/eula.rtf | 2 +- 19 files changed, 582 insertions(+), 296 deletions(-) create mode 100644 crates/collab/migrations/20240321162658_add_devservers.sql create mode 100644 crates/collab/src/db/queries/dev_servers.rs create mode 100644 crates/collab/src/db/tables/dev_server.rs create mode 100755 script/create-migration diff --git a/Cargo.lock b/Cargo.lock index c6bd889ea7..033c99d1cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12562,6 +12562,7 @@ dependencies = [ "call", "channel", "chrono", + "clap 4.4.4", "cli", "client", "clock", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 5abd530579..4c84935584 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -27,8 +27,8 @@ use release_channel::{AppVersion, ReleaseChannel}; use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; - use settings::{Settings, SettingsStore}; +use std::fmt; use std::{ any::TypeId, convert::TryFrom, @@ -52,6 +52,15 @@ pub use rpc::*; pub use telemetry_events::Event; pub use user::*; +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct DevServerToken(pub String); + +impl fmt::Display for DevServerToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + lazy_static! { static ref ZED_SERVER_URL: Option = std::env::var("ZED_SERVER_URL").ok(); static ref ZED_RPC_URL: Option = std::env::var("ZED_RPC_URL").ok(); @@ -277,10 +286,22 @@ enum WeakSubscriber { Pending(Vec>), } -#[derive(Clone, Debug)] -pub struct Credentials { - pub user_id: u64, - pub access_token: String, +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum Credentials { + DevServer { token: DevServerToken }, + User { user_id: u64, access_token: String }, +} + +impl Credentials { + pub fn authorization_header(&self) -> String { + match self { + Credentials::DevServer { token } => format!("dev-server-token {}", token), + Credentials::User { + user_id, + access_token, + } => format!("{} {}", user_id, access_token), + } + } } impl Default for ClientState { @@ -497,11 +518,11 @@ impl Client { } pub fn user_id(&self) -> Option { - self.state - .read() - .credentials - .as_ref() - .map(|credentials| credentials.user_id) + if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() { + Some(*user_id) + } else { + None + } } pub fn peer_id(&self) -> Option { @@ -746,6 +767,10 @@ impl Client { read_credentials_from_keychain(cx).await.is_some() } + pub fn set_dev_server_token(&self, token: DevServerToken) { + self.state.write().credentials = Some(Credentials::DevServer { token }); + } + #[async_recursion(?Send)] pub async fn authenticate_and_connect( self: &Arc, @@ -796,7 +821,9 @@ impl Client { } } let credentials = credentials.unwrap(); - self.set_id(credentials.user_id); + if let Credentials::User { user_id, .. } = &credentials { + self.set_id(*user_id); + } if was_disconnected { self.set_status(Status::Connecting, cx); @@ -812,7 +839,9 @@ impl Client { Ok(conn) => { self.state.write().credentials = Some(credentials.clone()); if !read_from_keychain && IMPERSONATE_LOGIN.is_none() { - write_credentials_to_keychain(credentials, cx).await.log_err(); + if let Credentials::User{user_id, access_token} = credentials { + write_credentials_to_keychain(user_id, access_token, cx).await.log_err(); + } } futures::select_biased! { @@ -1020,10 +1049,7 @@ impl Client { .unwrap_or_default(); let request = Request::builder() - .header( - "Authorization", - format!("{} {}", credentials.user_id, credentials.access_token), - ) + .header("Authorization", credentials.authorization_header()) .header("x-zed-protocol-version", rpc::PROTOCOL_VERSION) .header("x-zed-app-version", app_version) .header( @@ -1176,7 +1202,7 @@ impl Client { .decrypt_string(&access_token) .context("failed to decrypt access token")?; - Ok(Credentials { + Ok(Credentials::User { user_id: user_id.parse()?, access_token, }) @@ -1226,7 +1252,7 @@ impl Client { // Use the admin API token to authenticate as the impersonated user. api_token.insert_str(0, "ADMIN_TOKEN:"); - Ok(Credentials { + Ok(Credentials::User { user_id: response.user.id, access_token: api_token, }) @@ -1439,21 +1465,22 @@ async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option Result<()> { cx.update(move |cx| { cx.write_credentials( &ClientSettings::get_global(cx).server_url, - &credentials.user_id.to_string(), - credentials.access_token.as_bytes(), + &user_id.to_string(), + access_token.as_bytes(), ) })? .await @@ -1558,7 +1585,7 @@ mod tests { // Time out when client tries to connect. client.override_authenticate(move |cx| { cx.background_executor().spawn(async move { - Ok(Credentials { + Ok(Credentials::User { user_id, access_token: "token".into(), }) diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 9338e8cb91..5e8ad2181c 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -48,7 +48,7 @@ impl FakeServer { let mut state = state.lock(); state.auth_count += 1; let access_token = state.access_token.to_string(); - Ok(Credentials { + Ok(Credentials::User { user_id: client_user_id, access_token, }) @@ -71,9 +71,12 @@ impl FakeServer { )))? } - assert_eq!(credentials.user_id, client_user_id); - - if credentials.access_token != state.lock().access_token.to_string() { + if credentials + != (Credentials::User { + user_id: client_user_id, + access_token: state.lock().access_token.to_string(), + }) + { Err(EstablishConnectionError::Unauthorized)? } diff --git a/crates/collab/README.md b/crates/collab/README.md index bb3c76b15b..1af0b55d47 100644 --- a/crates/collab/README.md +++ b/crates/collab/README.md @@ -29,7 +29,7 @@ You can tell what is currently deployed with `./script/what-is-deployed`. To create a new migration: ``` -./script/sqlx migrate add +./script/create-migration ``` Migrations are run automatically on service start, so run `foreman start` again. The service will crash if the migrations fail. diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 29758d7eb1..d82ef75813 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -400,3 +400,11 @@ CREATE TABLE hosted_projects ( ); CREATE INDEX idx_hosted_projects_on_channel_id ON hosted_projects (channel_id); CREATE UNIQUE INDEX uix_hosted_projects_on_channel_id_and_name ON hosted_projects (channel_id, name) WHERE (deleted_at IS NULL); + +CREATE TABLE dev_servers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL REFERENCES channels(id), + name TEXT NOT NULL, + hashed_token TEXT NOT NULL +); +CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id); diff --git a/crates/collab/migrations/20240321162658_add_devservers.sql b/crates/collab/migrations/20240321162658_add_devservers.sql new file mode 100644 index 0000000000..cb1ff4df40 --- /dev/null +++ b/crates/collab/migrations/20240321162658_add_devservers.sql @@ -0,0 +1,7 @@ +CREATE TABLE dev_servers ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + channel_id INT NOT NULL REFERENCES channels(id), + name TEXT NOT NULL, + hashed_token TEXT NOT NULL +); +CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id); diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 26f6ede3d3..5daf6e6186 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,5 +1,6 @@ use crate::{ - db::{self, AccessTokenId, Database, UserId}, + db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId}, + rpc::Principal, AppState, Error, Result, }; use anyhow::{anyhow, Context}; @@ -19,11 +20,11 @@ use std::sync::OnceLock; use std::{sync::Arc, time::Instant}; use subtle::ConstantTimeEq; -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct Impersonator(pub Option); - -/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN -/// and one for the access tokens that we issue. +/// Validates the authorization header and adds an Extension to the request. +/// Authorization: +/// can be an access_token attached to that user, or an access token of an admin +/// or (in development) the string ADMIN:. +/// Authorization: "dev-server-token" pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { let mut auth_header = req .headers() @@ -37,7 +38,26 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into })? .split_whitespace(); - let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| { + let state = req.extensions().get::>().unwrap(); + + let first = auth_header.next().unwrap_or(""); + if first == "dev-server-token" { + let dev_server_token = auth_header.next().ok_or_else(|| { + Error::Http( + StatusCode::BAD_REQUEST, + "missing dev-server-token token in authorization header".to_string(), + ) + })?; + let dev_server = verify_dev_server_token(dev_server_token, &state.db) + .await + .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?; + + req.extensions_mut() + .insert(Principal::DevServer(dev_server)); + return Ok::<_, Error>(next.run(req).await); + } + + let user_id = UserId(first.parse().map_err(|_| { Error::Http( StatusCode::BAD_REQUEST, "missing user id in authorization header".to_string(), @@ -51,8 +71,6 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into ) })?; - let state = req.extensions().get::>().unwrap(); - // In development, allow impersonation using the admin API token. // Don't allow this in production because we can't tell who is doing // the impersonating. @@ -76,18 +94,17 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into .await? .ok_or_else(|| anyhow!("user {} not found", user_id))?; - let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id { - let impersonator = state + if let Some(impersonator_id) = validate_result.impersonator_id { + let admin = state .db .get_user_by_id(impersonator_id) .await? .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?; - Some(impersonator) + req.extensions_mut() + .insert(Principal::Impersonated { user, admin }); } else { - None + req.extensions_mut().insert(Principal::User(user)); }; - req.extensions_mut().insert(user); - req.extensions_mut().insert(Impersonator(impersonator)); return Ok::<_, Error>(next.run(req).await); } } @@ -213,6 +230,33 @@ pub async fn verify_access_token( }) } +// a dev_server_token has the format .. This is to make them +// relatively easy to copy/paste around. +pub async fn verify_dev_server_token( + dev_server_token: &str, + db: &Arc, +) -> anyhow::Result { + let mut parts = dev_server_token.splitn(2, '.'); + let id = DevServerId(parts.next().unwrap_or_default().parse()?); + let token = parts + .next() + .ok_or_else(|| anyhow!("invalid dev server token format"))?; + + let token_hash = hash_access_token(&token); + let server = db.get_dev_server(id).await?; + + if server + .hashed_token + .as_bytes() + .ct_eq(token_hash.as_ref()) + .into() + { + Ok(server) + } else { + Err(anyhow!("wrong token for dev server")) + } +} + #[cfg(test)] mod test { use rand::thread_rng; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index f465d3812a..91c0c440a5 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -67,28 +67,29 @@ macro_rules! id_type { }; } -id_type!(BufferId); id_type!(AccessTokenId); +id_type!(BufferId); +id_type!(ChannelBufferCollaboratorId); id_type!(ChannelChatParticipantId); id_type!(ChannelId); id_type!(ChannelMemberId); -id_type!(MessageId); id_type!(ContactId); +id_type!(DevServerId); +id_type!(ExtensionId); +id_type!(FlagId); id_type!(FollowerId); +id_type!(HostedProjectId); +id_type!(MessageId); +id_type!(NotificationId); +id_type!(NotificationKindId); +id_type!(ProjectCollaboratorId); +id_type!(ProjectId); +id_type!(ReplicaId); id_type!(RoomId); id_type!(RoomParticipantId); -id_type!(ProjectId); -id_type!(ProjectCollaboratorId); -id_type!(ReplicaId); id_type!(ServerId); id_type!(SignupId); id_type!(UserId); -id_type!(ChannelBufferCollaboratorId); -id_type!(FlagId); -id_type!(ExtensionId); -id_type!(NotificationId); -id_type!(NotificationKindId); -id_type!(HostedProjectId); /// ChannelRole gives you permissions for both channels and calls. #[derive( diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 7f2e345a59..0582b8f256 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod contributors; +pub mod dev_servers; pub mod extensions; pub mod hosted_projects; pub mod messages; diff --git a/crates/collab/src/db/queries/dev_servers.rs b/crates/collab/src/db/queries/dev_servers.rs new file mode 100644 index 0000000000..d95897b51e --- /dev/null +++ b/crates/collab/src/db/queries/dev_servers.rs @@ -0,0 +1,18 @@ +use sea_orm::EntityTrait; + +use super::{dev_server, Database, DevServerId}; + +impl Database { + pub async fn get_dev_server( + &self, + dev_server_id: DevServerId, + ) -> crate::Result { + self.transaction(|tx| async move { + Ok(dev_server::Entity::find_by_id(dev_server_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow::anyhow!("no dev server with id {}", dev_server_id))?) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 6864cc3782..b679337943 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -10,6 +10,7 @@ pub mod channel_message; pub mod channel_message_mention; pub mod contact; pub mod contributor; +pub mod dev_server; pub mod extension; pub mod extension_version; pub mod feature_flag; diff --git a/crates/collab/src/db/tables/dev_server.rs b/crates/collab/src/db/tables/dev_server.rs new file mode 100644 index 0000000000..94b1d4dc00 --- /dev/null +++ b/crates/collab/src/db/tables/dev_server.rs @@ -0,0 +1,17 @@ +use crate::db::{ChannelId, DevServerId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "dev_servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: DevServerId, + pub name: String, + pub channel_id: ChannelId, + pub hashed_token: String, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 735e1d3c50..9545b0c2e4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,12 +1,12 @@ mod connection_pool; use crate::{ - auth::{self, Impersonator}, + auth::{self}, db::{ - self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, - Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, - ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, - UpdatedChannelMessage, User, UserId, + self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, + CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId, + NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId, + RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId, }, executor::Executor, AppState, Error, RateLimit, RateLimiter, Result, @@ -64,7 +64,10 @@ use std::{ use time::OffsetDateTime; use tokio::sync::{watch, Semaphore}; use tower::ServiceBuilder; -use tracing::{field, info_span, instrument, Instrument}; +use tracing::{ + field::{self}, + info_span, instrument, Instrument, +}; use util::{http::IsahcHttpClient, SemanticVersion}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -105,9 +108,35 @@ impl StreamingResponse { } } +#[derive(Clone, Debug)] +pub enum Principal { + User(User), + Impersonated { user: User, admin: User }, + DevServer(dev_server::Model), +} + +impl Principal { + fn update_span(&self, span: &tracing::Span) { + match &self { + Principal::User(user) => { + span.record("user_id", &user.id.0); + span.record("login", &user.github_login); + } + Principal::Impersonated { user, admin } => { + span.record("user_id", &user.id.0); + span.record("login", &user.github_login); + span.record("impersonator", &admin.github_login); + } + Principal::DevServer(dev_server) => { + span.record("dev_server_id", &dev_server.id.0); + } + } + } +} + #[derive(Clone)] struct Session { - user_id: UserId, + principal: Principal, connection_id: ConnectionId, db: Arc>, peer: Arc, @@ -137,14 +166,98 @@ impl Session { _not_send: PhantomData, } } + + fn for_user(self) -> Option { + UserSession::new(self) + } + + fn user_id(&self) -> Option { + match &self.principal { + Principal::User(user) => Some(user.id), + Principal::Impersonated { user, .. } => Some(user.id), + Principal::DevServer(_) => None, + } + } } impl Debug for Session { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Session") - .field("user_id", &self.user_id) - .field("connection_id", &self.connection_id) - .finish() + let mut result = f.debug_struct("Session"); + match &self.principal { + Principal::User(user) => { + result.field("user", &user.github_login); + } + Principal::Impersonated { user, admin } => { + result.field("user", &user.github_login); + result.field("impersonator", &admin.github_login); + } + Principal::DevServer(dev_server) => { + result.field("dev_server", &dev_server.id); + } + } + result.field("connection_id", &self.connection_id).finish() + } +} + +struct UserSession(Session); + +impl UserSession { + pub fn new(s: Session) -> Option { + s.user_id().map(|_| UserSession(s)) + } + pub fn user_id(&self) -> UserId { + self.0.user_id().unwrap() + } +} + +impl Deref for UserSession { + type Target = Session; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for UserSession { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +fn user_handler( + handler: impl 'static + Send + Sync + Fn(M, Response, UserSession) -> Fut, +) -> impl 'static + Send + Sync + Fn(M, Response, Session) -> BoxFuture<'static, Result<()>> +where + Fut: Send + Future>, +{ + let handler = Arc::new(handler); + move |message, response, session| { + let handler = handler.clone(); + Box::pin(async move { + if let Some(user_session) = session.for_user() { + Ok(handler(message, response, user_session).await?) + } else { + Err(Error::Internal(anyhow!("must be a user"))) + } + }) + } +} + +fn user_message_handler( + handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut, +) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>> +where + InnertRetFut: Send + Future>, +{ + let handler = Arc::new(handler); + move |message, session| { + let handler = handler.clone(); + Box::pin(async move { + if let Some(user_session) = session.for_user() { + Ok(handler(message, user_session).await?) + } else { + Err(Error::Internal(anyhow!("must be a user"))) + } + }) } } @@ -201,20 +314,20 @@ impl Server { server .add_request_handler(ping) - .add_request_handler(create_room) - .add_request_handler(join_room) - .add_request_handler(rejoin_room) - .add_request_handler(leave_room) - .add_request_handler(set_room_participant_role) - .add_request_handler(call) - .add_request_handler(cancel_call) - .add_message_handler(decline_call) - .add_request_handler(update_participant_location) + .add_request_handler(user_handler(create_room)) + .add_request_handler(user_handler(join_room)) + .add_request_handler(user_handler(rejoin_room)) + .add_request_handler(user_handler(leave_room)) + .add_request_handler(user_handler(set_room_participant_role)) + .add_request_handler(user_handler(call)) + .add_request_handler(user_handler(cancel_call)) + .add_message_handler(user_message_handler(decline_call)) + .add_request_handler(user_handler(update_participant_location)) .add_request_handler(share_project) .add_message_handler(unshare_project) - .add_request_handler(join_project) - .add_request_handler(join_hosted_project) - .add_message_handler(leave_project) + .add_request_handler(user_handler(join_project)) + .add_request_handler(user_handler(join_hosted_project)) + .add_message_handler(user_message_handler(leave_project)) .add_request_handler(update_project) .add_request_handler(update_worktree) .add_message_handler(start_language_server) @@ -261,40 +374,40 @@ impl Server { .add_message_handler(broadcast_project_message_from_host::) .add_message_handler(broadcast_project_message_from_host::) .add_request_handler(get_users) - .add_request_handler(fuzzy_search_users) - .add_request_handler(request_contact) - .add_request_handler(remove_contact) - .add_request_handler(respond_to_contact_request) - .add_request_handler(create_channel) - .add_request_handler(delete_channel) - .add_request_handler(invite_channel_member) - .add_request_handler(remove_channel_member) - .add_request_handler(set_channel_member_role) - .add_request_handler(set_channel_visibility) - .add_request_handler(rename_channel) - .add_request_handler(join_channel_buffer) - .add_request_handler(leave_channel_buffer) - .add_message_handler(update_channel_buffer) - .add_request_handler(rejoin_channel_buffers) - .add_request_handler(get_channel_members) - .add_request_handler(respond_to_channel_invite) - .add_request_handler(join_channel) - .add_request_handler(join_channel_chat) - .add_message_handler(leave_channel_chat) - .add_request_handler(send_channel_message) - .add_request_handler(remove_channel_message) - .add_request_handler(update_channel_message) - .add_request_handler(get_channel_messages) - .add_request_handler(get_channel_messages_by_id) - .add_request_handler(get_notifications) - .add_request_handler(mark_notification_as_read) - .add_request_handler(move_channel) - .add_request_handler(follow) - .add_message_handler(unfollow) - .add_message_handler(update_followers) - .add_request_handler(get_private_user_info) - .add_message_handler(acknowledge_channel_message) - .add_message_handler(acknowledge_buffer_version) + .add_request_handler(user_handler(fuzzy_search_users)) + .add_request_handler(user_handler(request_contact)) + .add_request_handler(user_handler(remove_contact)) + .add_request_handler(user_handler(respond_to_contact_request)) + .add_request_handler(user_handler(create_channel)) + .add_request_handler(user_handler(delete_channel)) + .add_request_handler(user_handler(invite_channel_member)) + .add_request_handler(user_handler(remove_channel_member)) + .add_request_handler(user_handler(set_channel_member_role)) + .add_request_handler(user_handler(set_channel_visibility)) + .add_request_handler(user_handler(rename_channel)) + .add_request_handler(user_handler(join_channel_buffer)) + .add_request_handler(user_handler(leave_channel_buffer)) + .add_message_handler(user_message_handler(update_channel_buffer)) + .add_request_handler(user_handler(rejoin_channel_buffers)) + .add_request_handler(user_handler(get_channel_members)) + .add_request_handler(user_handler(respond_to_channel_invite)) + .add_request_handler(user_handler(join_channel)) + .add_request_handler(user_handler(join_channel_chat)) + .add_message_handler(user_message_handler(leave_channel_chat)) + .add_request_handler(user_handler(send_channel_message)) + .add_request_handler(user_handler(remove_channel_message)) + .add_request_handler(user_handler(update_channel_message)) + .add_request_handler(user_handler(get_channel_messages)) + .add_request_handler(user_handler(get_channel_messages_by_id)) + .add_request_handler(user_handler(get_notifications)) + .add_request_handler(user_handler(mark_notification_as_read)) + .add_request_handler(user_handler(move_channel)) + .add_request_handler(user_handler(follow)) + .add_message_handler(user_message_handler(unfollow)) + .add_message_handler(user_message_handler(update_followers)) + .add_request_handler(user_handler(get_private_user_info)) + .add_message_handler(user_message_handler(acknowledge_channel_message)) + .add_message_handler(user_message_handler(acknowledge_buffer_version)) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -309,14 +422,14 @@ impl Server { }) .add_request_handler({ let app_state = app_state.clone(); - move |request, response, session| { + user_handler(move |request, response, session| { count_tokens_with_language_model( request, response, session, app_state.config.google_ai_api_key.clone(), ) - } + }) }); Arc::new(server) @@ -612,19 +725,15 @@ impl Server { self: &Arc, connection: Connection, address: String, - user: User, + principal: Principal, zed_version: ZedVersion, - impersonator: Option, send_connection_id: Option>, executor: Executor, ) -> impl Future { let this = self.clone(); - let user_id = user.id; - let login = user.github_login.clone(); - let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty); - if let Some(impersonator) = impersonator { - span.record("impersonator", &impersonator.github_login); - } + let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty); + principal.update_span(&span); + let mut teardown = self.teardown.subscribe(); async move { if *teardown.borrow() { @@ -649,7 +758,7 @@ impl Server { }; let session = Session { - user_id, + principal: principal.clone(), connection_id, db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))), peer: this.peer.clone(), @@ -660,7 +769,7 @@ impl Server { _executor: executor.clone(), }; - if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await { + if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await { tracing::error!(?error, "failed to send initial client update"); return; } @@ -700,7 +809,8 @@ impl Server { let type_name = message.payload_type_name(); // note: we copy all the fields from the parent span so we can query them in the logs. // (https://github.com/tokio-rs/tracing/issues/2670). - let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name); + let span = tracing::info_span!("receive message", %connection_id, %address, type_name); + principal.update_span(&span); let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); @@ -739,7 +849,7 @@ impl Server { async fn send_initial_client_update( &self, connection_id: ConnectionId, - user: User, + principal: &Principal, zed_version: ZedVersion, mut send_connection_id: Option>, session: &Session, @@ -752,6 +862,10 @@ impl Server { )?; tracing::info!("sent hello message"); + let Principal::User(user) = principal else { + return Ok(()); + }; + if let Some(send_connection_id) = send_connection_id.take() { let _ = send_connection_id.send(connection_id); } @@ -970,8 +1084,7 @@ pub async fn handle_websocket_request( app_version_header: Option>, ConnectInfo(socket_address): ConnectInfo, Extension(server): Extension>, - Extension(user): Extension, - Extension(impersonator): Extension, + Extension(principal): Extension, ws: WebSocketUpgrade, ) -> axum::response::Response { if protocol_version != rpc::PROTOCOL_VERSION { @@ -1010,9 +1123,8 @@ pub async fn handle_websocket_request( .handle_connection( connection, socket_address, - user, + principal, version, - impersonator.0, None, Executor::Production, ) @@ -1075,24 +1187,26 @@ async fn connection_lost( futures::select_biased! { _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => { - log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id); - leave_room_for_session(&session).await.trace_err(); - leave_channel_buffers_for_session(&session) - .await - .trace_err(); + if let Some(session) = session.for_user() { + log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id); + leave_room_for_session(&session).await.trace_err(); + leave_channel_buffers_for_session(&session) + .await + .trace_err(); - if !session - .connection_pool() - .await - .is_user_online(session.user_id) - { - let db = session.db().await; - if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() { - room_updated(&room, &session.peer); + if !session + .connection_pool() + .await + .is_user_online(session.user_id()) + { + let db = session.db().await; + if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() { + room_updated(&room, &session.peer); + } } - } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; + } } _ = teardown.changed().fuse() => {} } @@ -1110,19 +1224,20 @@ async fn ping(_: proto::Ping, response: Response, _session: Session async fn create_room( _request: proto::CreateRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let live_kit_room = nanoid::nanoid!(30); let live_kit_connection_info = { let live_kit_room = live_kit_room.clone(); let live_kit = session.live_kit_client.as_ref(); + let user_id = session.user_id().to_string(); util::async_maybe!({ let live_kit = live_kit?; let token = live_kit - .room_token(&live_kit_room, &session.user_id.to_string()) + .room_token(&live_kit_room, &user_id.to_string()) .trace_err()?; Some(proto::LiveKitConnectionInfo { @@ -1137,7 +1252,7 @@ async fn create_room( let room = session .db() .await - .create_room(session.user_id, session.connection_id, &live_kit_room) + .create_room(session.user_id(), session.connection_id, &live_kit_room) .await?; response.send(proto::CreateRoomResponse { @@ -1145,7 +1260,7 @@ async fn create_room( live_kit_connection_info, })?; - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1153,7 +1268,7 @@ async fn create_room( async fn join_room( request: proto::JoinRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.id); @@ -1167,7 +1282,7 @@ async fn join_room( let room = session .db() .await - .join_room(room_id, session.user_id, session.connection_id) + .join_room(room_id, session.user_id(), session.connection_id) .await?; room_updated(&room.room, &session.peer); room.into_inner() @@ -1176,7 +1291,7 @@ async fn join_room( for connection_id in session .connection_pool() .await - .user_connection_ids(session.user_id) + .user_connection_ids(session.user_id()) { session .peer @@ -1193,7 +1308,7 @@ async fn join_room( if let Some(token) = live_kit .room_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err() { @@ -1215,7 +1330,7 @@ async fn join_room( live_kit_connection_info, })?; - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1223,7 +1338,7 @@ async fn join_room( async fn rejoin_room( request: proto::RejoinRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room; let channel; @@ -1231,7 +1346,7 @@ async fn rejoin_room( let mut rejoined_room = session .db() .await - .rejoin_room(request, session.user_id, session.connection_id) + .rejoin_room(request, session.user_id(), session.connection_id) .await?; response.send(proto::RejoinRoomResponse { @@ -1404,7 +1519,7 @@ async fn rejoin_room( ); } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1412,7 +1527,7 @@ async fn rejoin_room( async fn leave_room( _: proto::LeaveRoom, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { leave_room_for_session(&session).await?; response.send(proto::Ack {})?; @@ -1423,7 +1538,7 @@ async fn leave_room( async fn set_room_participant_role( request: proto::SetRoomParticipantRole, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let user_id = UserId::from_proto(request.user_id); let role = ChannelRole::from(request.role()); @@ -1433,7 +1548,7 @@ async fn set_room_participant_role( .db() .await .set_room_participant_role( - session.user_id, + session.user_id(), RoomId::from_proto(request.room_id), user_id, role, @@ -1471,10 +1586,10 @@ async fn set_room_participant_role( async fn call( request: proto::Call, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); - let calling_user_id = session.user_id; + let calling_user_id = session.user_id(); let calling_connection_id = session.connection_id; let called_user_id = UserId::from_proto(request.called_user_id); let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); @@ -1540,7 +1655,7 @@ async fn call( async fn cancel_call( request: proto::CancelCall, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); @@ -1575,13 +1690,13 @@ async fn cancel_call( } /// Decline an incoming call. -async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { +async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); { let room = session .db() .await - .decline_call(Some(room_id), session.user_id) + .decline_call(Some(room_id), session.user_id()) .await? .ok_or_else(|| anyhow!("failed to decline call"))?; room_updated(&room, &session.peer); @@ -1590,7 +1705,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( for connection_id in session .connection_pool() .await - .user_connection_ids(session.user_id) + .user_connection_ids(session.user_id()) { session .peer @@ -1602,7 +1717,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( ) .trace_err(); } - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -1610,7 +1725,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<( async fn update_participant_location( request: proto::UpdateParticipantLocation, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let location = request @@ -1674,7 +1789,7 @@ async fn unshare_project(message: proto::UnshareProject, session: Session) -> Re async fn join_project( request: proto::JoinProject, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); @@ -1705,7 +1820,7 @@ impl JoinProjectInternalResponse for Response { fn join_project_internal( response: impl JoinProjectInternalResponse, - session: Session, + session: UserSession, project: &mut Project, replica_id: &ReplicaId, ) -> Result<()> { @@ -1716,7 +1831,7 @@ fn join_project_internal( .map(|collaborator| collaborator.to_proto()) .collect::>(); let project_id = project.id; - let guest_user_id = session.user_id; + let guest_user_id = session.user_id(); let worktrees = project .worktrees @@ -1823,7 +1938,7 @@ fn join_project_internal( } /// Leave someone elses shared project. -async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { +async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); let db = session.db().await; @@ -1850,14 +1965,14 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result async fn join_hosted_project( request: proto::JoinHostedProject, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let (mut project, replica_id) = session .db() .await .join_hosted_project( ProjectId(request.project_id as i32), - session.user_id, + session.user_id(), session.connection_id, ) .await?; @@ -2168,7 +2283,7 @@ async fn broadcast_project_message_from_host, - session: Session, + session: UserSession, ) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); @@ -2203,7 +2318,7 @@ async fn follow( } /// Stop following another user in a call. -async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { +async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let project_id = request.project_id.map(ProjectId::from_proto); let leader_id = request @@ -2235,7 +2350,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { } /// Notify everyone following you of your current location. -async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { +async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> { let room_id = RoomId::from_proto(request.room_id); let database = session.db.lock().await; @@ -2297,7 +2412,7 @@ async fn get_users( async fn fuzzy_search_users( request: proto::FuzzySearchUsers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let query = request.query; let users = match query.len() { @@ -2313,7 +2428,7 @@ async fn fuzzy_search_users( }; let users = users .into_iter() - .filter(|user| user.id != session.user_id) + .filter(|user| user.id != session.user_id()) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -2328,9 +2443,9 @@ async fn fuzzy_search_users( async fn request_contact( request: proto::RequestContact, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let requester_id = session.user_id; + let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; @@ -2375,9 +2490,9 @@ async fn request_contact( async fn respond_to_contact_request( request: proto::RespondToContactRequest, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let responder_id = session.user_id; + let responder_id = session.user_id(); let requester_id = UserId::from_proto(request.requester_id); let db = session.db().await; if request.response == proto::ContactRequestResponse::Dismiss as i32 { @@ -2433,9 +2548,9 @@ async fn respond_to_contact_request( async fn remove_contact( request: proto::RemoveContact, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { - let requester_id = session.user_id; + let requester_id = session.user_id(); let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; let (contact_accepted, deleted_notification_id) = @@ -2484,13 +2599,13 @@ async fn remove_contact( async fn create_channel( request: proto::CreateChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id)); let (channel, membership) = db - .create_channel(&request.name, parent_id, session.user_id) + .create_channel(&request.name, parent_id, session.user_id()) .await?; let root_id = channel.root_id(); @@ -2539,13 +2654,13 @@ async fn create_channel( async fn delete_channel( request: proto::DeleteChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = request.channel_id; let (root_channel, removed_channels) = db - .delete_channel(ChannelId::from_proto(channel_id), session.user_id) + .delete_channel(ChannelId::from_proto(channel_id), session.user_id()) .await?; response.send(proto::Ack {})?; @@ -2567,7 +2682,7 @@ async fn delete_channel( async fn invite_channel_member( request: proto::InviteChannelMember, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2579,7 +2694,7 @@ async fn invite_channel_member( .invite_channel_member( channel_id, invitee_id, - session.user_id, + session.user_id(), request.role().into(), ) .await?; @@ -2604,7 +2719,7 @@ async fn invite_channel_member( async fn remove_channel_member( request: proto::RemoveChannelMember, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2614,7 +2729,7 @@ async fn remove_channel_member( membership_update, notification_id, } = db - .remove_channel_member(channel_id, member_id, session.user_id) + .remove_channel_member(channel_id, member_id, session.user_id()) .await?; let mut connection_pool = session.connection_pool().await; @@ -2648,14 +2763,14 @@ async fn remove_channel_member( async fn set_channel_visibility( request: proto::SetChannelVisibility, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let visibility = request.visibility().into(); let channel_model = db - .set_channel_visibility(channel_id, visibility, session.user_id) + .set_channel_visibility(channel_id, visibility, session.user_id()) .await?; let root_id = channel_model.root_id(); let channel = Channel::from_model(channel_model); @@ -2693,7 +2808,7 @@ async fn set_channel_visibility( async fn set_channel_member_role( request: proto::SetChannelMemberRole, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2701,7 +2816,7 @@ async fn set_channel_member_role( let result = db .set_channel_member_role( channel_id, - session.user_id, + session.user_id(), member_id, request.role().into(), ) @@ -2741,12 +2856,12 @@ async fn set_channel_member_role( async fn rename_channel( request: proto::RenameChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let channel_model = db - .rename_channel(channel_id, session.user_id, &request.name) + .rename_channel(channel_id, session.user_id(), &request.name) .await?; let root_id = channel_model.root_id(); let channel = Channel::from_model(channel_model); @@ -2773,7 +2888,7 @@ async fn rename_channel( async fn move_channel( request: proto::MoveChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let to = ChannelId::from_proto(request.to); @@ -2781,7 +2896,7 @@ async fn move_channel( let (root_id, channels) = session .db() .await - .move_channel(channel_id, to, session.user_id) + .move_channel(channel_id, to, session.user_id()) .await?; let connection_pool = session.connection_pool().await; @@ -2816,12 +2931,12 @@ async fn move_channel( async fn get_channel_members( request: proto::GetChannelMembers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let members = db - .get_channel_participant_details(channel_id, session.user_id) + .get_channel_participant_details(channel_id, session.user_id()) .await?; response.send(proto::GetChannelMembersResponse { members })?; Ok(()) @@ -2831,7 +2946,7 @@ async fn get_channel_members( async fn respond_to_channel_invite( request: proto::RespondToChannelInvite, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -2839,7 +2954,7 @@ async fn respond_to_channel_invite( membership_update, notifications, } = db - .respond_to_channel_invite(channel_id, session.user_id, request.accept) + .respond_to_channel_invite(channel_id, session.user_id(), request.accept) .await?; let mut connection_pool = session.connection_pool().await; @@ -2847,7 +2962,7 @@ async fn respond_to_channel_invite( notify_membership_updated( &mut connection_pool, membership_update, - session.user_id, + session.user_id(), &session.peer, ); } else { @@ -2856,7 +2971,7 @@ async fn respond_to_channel_invite( ..Default::default() }; - for connection_id in connection_pool.user_connection_ids(session.user_id) { + for connection_id in connection_pool.user_connection_ids(session.user_id()) { session.peer.send(connection_id, update.clone())?; } }; @@ -2872,7 +2987,7 @@ async fn respond_to_channel_invite( async fn join_channel( request: proto::JoinChannel, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); join_channel_internal(channel_id, Box::new(response), session).await @@ -2895,14 +3010,14 @@ impl JoinChannelInternalResponse for Response { async fn join_channel_internal( channel_id: ChannelId, response: Box, - session: Session, + session: UserSession, ) -> Result<()> { let joined_room = { leave_room_for_session(&session).await?; let db = session.db().await; let (joined_room, membership_updated, role) = db - .join_channel(channel_id, session.user_id, session.connection_id) + .join_channel(channel_id, session.user_id(), session.connection_id) .await?; let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| { @@ -2912,7 +3027,7 @@ async fn join_channel_internal( live_kit .guest_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err()?, ) @@ -2922,7 +3037,7 @@ async fn join_channel_internal( live_kit .room_token( &joined_room.room.live_kit_room, - &session.user_id.to_string(), + &session.user_id().to_string(), ) .trace_err()?, ) @@ -2949,7 +3064,7 @@ async fn join_channel_internal( notify_membership_updated( &mut connection_pool, membership_updated, - session.user_id, + session.user_id(), &session.peer, ); } @@ -2968,7 +3083,7 @@ async fn join_channel_internal( &*session.connection_pool().await, ); - update_user_contacts(session.user_id, &session).await?; + update_user_contacts(session.user_id(), &session).await?; Ok(()) } @@ -2976,13 +3091,13 @@ async fn join_channel_internal( async fn join_channel_buffer( request: proto::JoinChannelBuffer, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let open_response = db - .join_channel_buffer(channel_id, session.user_id, session.connection_id) + .join_channel_buffer(channel_id, session.user_id(), session.connection_id) .await?; let collaborators = open_response.collaborators.clone(); @@ -3007,13 +3122,13 @@ async fn join_channel_buffer( /// Edit the channel notes async fn update_channel_buffer( request: proto::UpdateChannelBuffer, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let (collaborators, non_collaborators, epoch, version) = db - .update_channel_buffer(channel_id, session.user_id, &request.operations) + .update_channel_buffer(channel_id, session.user_id(), &request.operations) .await?; channel_buffer_updated( @@ -3055,11 +3170,11 @@ async fn update_channel_buffer( async fn rejoin_channel_buffers( request: proto::RejoinChannelBuffers, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let buffers = db - .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id) + .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id) .await?; for rejoined_buffer in &buffers { @@ -3090,7 +3205,7 @@ async fn rejoin_channel_buffers( async fn leave_channel_buffer( request: proto::LeaveChannelBuffer, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); @@ -3152,7 +3267,7 @@ fn send_notifications( async fn send_channel_message( request: proto::SendChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { // Validate the message body. let body = request.body.trim().to_string(); @@ -3181,7 +3296,7 @@ async fn send_channel_message( .await .create_channel_message( channel_id, - session.user_id, + session.user_id(), &body, &request.mentions, timestamp, @@ -3194,7 +3309,7 @@ async fn send_channel_message( .await?; let message = proto::ChannelMessage { - sender_id: session.user_id.to_proto(), + sender_id: session.user_id().to_proto(), id: message_id.to_proto(), body, mentions: request.mentions, @@ -3248,14 +3363,14 @@ async fn send_channel_message( async fn remove_channel_message( request: proto::RemoveChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); let connection_ids = session .db() .await - .remove_channel_message(channel_id, message_id, session.user_id) + .remove_channel_message(channel_id, message_id, session.user_id()) .await?; broadcast(Some(session.connection_id), connection_ids, |connection| { session.peer.send(connection, request.clone()) @@ -3267,7 +3382,7 @@ async fn remove_channel_message( async fn update_channel_message( request: proto::UpdateChannelMessage, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); @@ -3284,7 +3399,7 @@ async fn update_channel_message( .update_channel_message( channel_id, message_id, - session.user_id, + session.user_id(), request.body.as_str(), &request.mentions, updated_at, @@ -3297,7 +3412,7 @@ async fn update_channel_message( .ok_or_else(|| anyhow!("nonce can't be blank"))?; let message = proto::ChannelMessage { - sender_id: session.user_id.to_proto(), + sender_id: session.user_id().to_proto(), id: message_id.to_proto(), body: request.body.clone(), mentions: request.mentions.clone(), @@ -3332,14 +3447,14 @@ async fn update_channel_message( /// Mark a channel message as read async fn acknowledge_channel_message( request: proto::AckChannelMessage, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); let notifications = session .db() .await - .observe_channel_message(channel_id, session.user_id, message_id) + .observe_channel_message(channel_id, session.user_id(), message_id) .await?; send_notifications( &*session.connection_pool().await, @@ -3352,7 +3467,7 @@ async fn acknowledge_channel_message( /// Mark a buffer version as synced async fn acknowledge_buffer_version( request: proto::AckBufferOperation, - session: Session, + session: UserSession, ) -> Result<()> { let buffer_id = BufferId::from_proto(request.buffer_id); session @@ -3360,7 +3475,7 @@ async fn acknowledge_buffer_version( .await .observe_buffer_version( buffer_id, - session.user_id, + session.user_id(), request.epoch as i32, &request.version, ) @@ -3394,10 +3509,13 @@ async fn complete_with_language_model( open_ai_api_key: Option>, google_ai_api_key: Option>, ) -> Result<()> { + let Some(session) = session.for_user() else { + return Err(anyhow!("user not found"))?; + }; authorize_access_to_language_models(&session).await?; session .rate_limiter - .check::(session.user_id) + .check::(session.user_id()) .await?; if request.model.starts_with("gpt") { @@ -3416,7 +3534,7 @@ async fn complete_with_language_model( async fn complete_with_open_ai( request: proto::CompleteWithLanguageModel, response: StreamingResponse, - session: Session, + session: UserSession, api_key: Arc, ) -> Result<()> { const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; @@ -3458,7 +3576,7 @@ async fn complete_with_open_ai( async fn complete_with_google_ai( request: proto::CompleteWithLanguageModel, response: StreamingResponse, - session: Session, + session: UserSession, api_key: Arc, ) -> Result<()> { let mut stream = google_ai::stream_generate_content( @@ -3527,7 +3645,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit { async fn count_tokens_with_language_model( request: proto::CountTokensWithLanguageModel, response: Response, - session: Session, + session: UserSession, google_ai_api_key: Option>, ) -> Result<()> { authorize_access_to_language_models(&session).await?; @@ -3541,7 +3659,7 @@ async fn count_tokens_with_language_model( session .rate_limiter - .check::(session.user_id) + .check::(session.user_id()) .await?; let api_key = google_ai_api_key @@ -3559,9 +3677,9 @@ async fn count_tokens_with_language_model( Ok(()) } -async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> { +async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> { let db = session.db().await; - let flags = db.get_user_flags(session.user_id).await?; + let flags = db.get_user_flags(session.user_id()).await?; if flags.iter().any(|flag| flag == "language-models") { Ok(()) } else { @@ -3573,15 +3691,15 @@ async fn authorize_access_to_language_models(session: &Session) -> Result<(), Er async fn join_channel_chat( request: proto::JoinChannelChat, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let db = session.db().await; - db.join_channel_chat(channel_id, session.connection_id, session.user_id) + db.join_channel_chat(channel_id, session.connection_id, session.user_id()) .await?; let messages = db - .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None) + .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None) .await?; response.send(proto::JoinChannelChatResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, @@ -3591,12 +3709,12 @@ async fn join_channel_chat( } /// Stop receiving chat updates for a channel -async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> { +async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); session .db() .await - .leave_channel_chat(channel_id, session.connection_id, session.user_id) + .leave_channel_chat(channel_id, session.connection_id, session.user_id()) .await?; Ok(()) } @@ -3605,7 +3723,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) async fn get_channel_messages( request: proto::GetChannelMessages, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let messages = session @@ -3613,7 +3731,7 @@ async fn get_channel_messages( .await .get_channel_messages( channel_id, - session.user_id, + session.user_id(), MESSAGE_COUNT_PER_PAGE, Some(MessageId::from_proto(request.before_message_id)), ) @@ -3629,7 +3747,7 @@ async fn get_channel_messages( async fn get_channel_messages_by_id( request: proto::GetChannelMessagesById, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let message_ids = request .message_ids @@ -3639,7 +3757,7 @@ async fn get_channel_messages_by_id( let messages = session .db() .await - .get_channel_messages_by_id(session.user_id, &message_ids) + .get_channel_messages_by_id(session.user_id(), &message_ids) .await?; response.send(proto::GetChannelMessagesResponse { done: messages.len() < MESSAGE_COUNT_PER_PAGE, @@ -3652,13 +3770,13 @@ async fn get_channel_messages_by_id( async fn get_notifications( request: proto::GetNotifications, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let notifications = session .db() .await .get_notifications( - session.user_id, + session.user_id(), NOTIFICATION_COUNT_PER_PAGE, request .before_id @@ -3676,12 +3794,12 @@ async fn get_notifications( async fn mark_notification_as_read( request: proto::MarkNotificationRead, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let database = &session.db().await; let notifications = database .mark_notification_as_read_by_id( - session.user_id, + session.user_id(), NotificationId::from_proto(request.notification_id), ) .await?; @@ -3698,16 +3816,16 @@ async fn mark_notification_as_read( async fn get_private_user_info( _request: proto::GetPrivateUserInfo, response: Response, - session: Session, + session: UserSession, ) -> Result<()> { let db = session.db().await; - let metrics_id = db.get_user_metrics_id(session.user_id).await?; + let metrics_id = db.get_user_metrics_id(session.user_id()).await?; let user = db - .get_user_by_id(session.user_id) + .get_user_by_id(session.user_id()) .await? .ok_or_else(|| anyhow!("user not found"))?; - let flags = db.get_user_flags(session.user_id).await?; + let flags = db.get_user_flags(session.user_id()).await?; response.send(proto::GetPrivateUserInfoResponse { metrics_id, @@ -3951,7 +4069,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> Ok(()) } -async fn leave_room_for_session(session: &Session) -> Result<()> { +async fn leave_room_for_session(session: &UserSession) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_id; @@ -3962,7 +4080,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let channel; if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? { - contacts_to_update.insert(session.user_id); + contacts_to_update.insert(session.user_id()); for project in left_room.left_projects.values() { project_left(project, session); @@ -4013,7 +4131,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { if let Some(live_kit) = session.live_kit_client.as_ref() { live_kit - .remove_participant(live_kit_room.clone(), session.user_id.to_string()) + .remove_participant(live_kit_room.clone(), session.user_id().to_string()) .await .trace_err(); @@ -4047,9 +4165,9 @@ async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> { Ok(()) } -fn project_left(project: &db::LeftProject, session: &Session) { +fn project_left(project: &db::LeftProject, session: &UserSession) { for connection_id in &project.connection_ids { - if project.host_user_id == Some(session.user_id) { + if project.host_user_id == Some(session.user_id()) { session .peer .send( diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index e5ca052a2f..3027848b2b 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -1,7 +1,7 @@ use crate::{ db::{tests::TestDb, NewUserParams, UserId}, executor::Executor, - rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, + rpc::{Principal, Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, AppState, Config, RateLimiter, }; use anyhow::anyhow; @@ -197,15 +197,20 @@ impl TestServer { .override_authenticate(move |cx| { cx.spawn(|_| async move { let access_token = "the-token".to_string(); - Ok(Credentials { + Ok(Credentials::User { user_id: user_id.to_proto(), access_token, }) }) }) .override_establish_connection(move |credentials, cx| { - assert_eq!(credentials.user_id, user_id.0 as u64); - assert_eq!(credentials.access_token, "the-token"); + assert_eq!( + credentials, + &Credentials::User { + user_id: user_id.0 as u64, + access_token: "the-token".into() + } + ); let server = server.clone(); let db = db.clone(); @@ -230,9 +235,8 @@ impl TestServer { .spawn(server.handle_connection( server_conn, client_name, - user, + Principal::User(user), ZedVersion(SemanticVersion::new(1, 0, 0)), - None, Some(connection_id_tx), Executor::Deterministic(cx.background_executor().clone()), )) diff --git a/crates/theme/theme.md b/crates/theme/theme.md index f9a7a58178..d19d147597 100644 --- a/crates/theme/theme.md +++ b/crates/theme/theme.md @@ -1,15 +1,15 @@ - # Theme +# Theme - This crate provides the theme system for Zed. +This crate provides the theme system for Zed. - ## Overview +## Overview - A theme is a collection of colors used to build a consistent appearance for UI components across the application. - To produce a theme in Zed, +A theme is a collection of colors used to build a consistent appearance for UI components across the application. +To produce a theme in Zed, - A theme is made of of two parts: A [ThemeFamily] and one or more [Theme]s. +A theme is made of of two parts: A [ThemeFamily] and one or more [Theme]s. // - A [ThemeFamily] contains metadata like theme name, author, and theme-specific [ColorScales] as well as a series of themes. +A [ThemeFamily] contains metadata like theme name, author, and theme-specific [ColorScales] as well as a series of themes. - - [ThemeColors] - A set of colors that are used to style the UI. Refer to the [ThemeColors] documentation for more information. +- [ThemeColors] - A set of colors that are used to style the UI. Refer to the [ThemeColors] documentation for more information. diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index b5cf2d9e54..8195a2bf93 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -26,6 +26,7 @@ breadcrumbs.workspace = true call.workspace = true channel.workspace = true chrono.workspace = true +clap.workspace = true cli.workspace = true client.workspace = true clock.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index bb6a2ca6e8..a5e959d757 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -6,8 +6,9 @@ mod zed; use anyhow::{anyhow, Context as _, Result}; use backtrace::Backtrace; use chrono::Utc; +use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{parse_zed_link, Client, UserStore}; +use client::{parse_zed_link, Client, ClientSettings, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; use db::kvp::KEY_VALUE_STORE; use editor::Editor; @@ -270,9 +271,28 @@ fn main() { cx.activate(true); - let urls = collect_url_args(cx); - if !urls.is_empty() { - listener.open_urls(urls) + let mut args = Args::parse(); + if let Some(dev_server_token) = args.dev_server_token.take() { + let dev_server_token = DevServerToken(dev_server_token); + let server_url = ClientSettings::get_global(&cx).server_url.clone(); + let client = client.clone(); + client.set_dev_server_token(dev_server_token); + cx.spawn(|cx| async move { + client.authenticate_and_connect(false, &cx).await?; + log::info!("Connected to {}", server_url); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } else { + let urls: Vec<_> = args + .paths_or_urls + .iter() + .filter_map(|arg| parse_url_arg(arg, cx).log_err()) + .collect(); + + if !urls.is_empty() { + listener.open_urls(urls) + } } let mut triggered_authentication = false; @@ -898,23 +918,35 @@ fn stdout_is_a_pty() -> bool { std::env::var(FORCE_CLI_MODE_ENV_VAR_NAME).ok().is_none() && std::io::stdout().is_terminal() } -fn collect_url_args(cx: &AppContext) -> Vec { - env::args() - .skip(1) - .filter_map(|arg| match std::fs::canonicalize(Path::new(&arg)) { - Ok(path) => Some(format!("file://{}", path.to_string_lossy())), - Err(error) => { - if arg.starts_with("file://") || arg.starts_with("zed-cli://") { - Some(arg) - } else if let Some(_) = parse_zed_link(&arg, cx) { - Some(arg) - } else { - log::error!("error parsing path argument: {}", error); - None - } +#[derive(Parser, Debug)] +#[command(name = "zed", disable_version_flag = true)] +struct Args { + /// A sequence of space-separated paths or urls that you want to open. + /// + /// Use `path:line:row` syntax to open a file at a specific location. + /// Non-existing paths and directories will ignore `:line:row` suffix. + /// + /// URLs can either be file:// or zed:// scheme, or relative to https://zed.dev. + paths_or_urls: Vec, + + /// Instructs zed to run as a dev server on this machine. (not implemented) + #[arg(long)] + dev_server_token: Option, +} + +fn parse_url_arg(arg: &str, cx: &AppContext) -> Result { + match std::fs::canonicalize(Path::new(&arg)) { + Ok(path) => Ok(format!("file://{}", path.to_string_lossy())), + Err(error) => { + if arg.starts_with("file://") || arg.starts_with("zed-cli://") { + Ok(arg.into()) + } else if let Some(_) = parse_zed_link(&arg, cx) { + Ok(arg.into()) + } else { + Err(anyhow!("error parsing path argument: {}", error)) } - }) - .collect() + } + } } fn load_embedded_fonts(cx: &AppContext) { diff --git a/script/create-migration b/script/create-migration new file mode 100755 index 0000000000..187336be19 --- /dev/null +++ b/script/create-migration @@ -0,0 +1,3 @@ +zed . \ + "crates/collab/migrations.sqlite/20221109000000_test_schema.sql" \ + "crates/collab/migrations/$(date -u +%Y%m%d%H%M%S)_$(echo $1 | sed 's/[^a-z0-9]/_/g').sql" diff --git a/script/eula/eula.rtf b/script/eula/eula.rtf index 3bdbb463bb..6feaff789c 100644 --- a/script/eula/eula.rtf +++ b/script/eula/eula.rtf @@ -182,4 +182,4 @@ \f0\b \cf2 DATE: April 5, 2023 \f1\b0 \ -} \ No newline at end of file +}