diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 68f5299a8b..e0d65e8969 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -450,8 +450,8 @@ mod tests { #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { let user_id = 5; - let client = Client::new(); - let server = FakeServer::for_client(user_id, &client, &cx).await; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; let user_store = Arc::new(UserStore::new(client.clone())); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 64fc8a56ea..81cc10395f 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -28,12 +28,21 @@ lazy_static! { pub struct Client { peer: Arc, state: RwLock, + auth_callback: Option< + Box Task>>, + >, + connect_callback: Option< + Box Task>>, + >, } #[derive(Copy, Clone, Debug)] pub enum Status { Disconnected, - Connecting, + Authenticating, + Connecting { + user_id: u64, + }, ConnectionError, Connected { connection_id: ConnectionId, @@ -94,9 +103,24 @@ impl Client { Arc::new(Self { peer: Peer::new(), state: Default::default(), + auth_callback: None, + connect_callback: None, }) } + #[cfg(any(test, feature = "test-support"))] + pub fn set_login_and_connect_callbacks( + &mut self, + login: Login, + connect: Connect, + ) where + Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task>, + Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task>, + { + self.auth_callback = Some(Box::new(login)); + self.connect_callback = Some(Box::new(connect)); + } + pub fn status(&self) -> watch::Receiver { self.state.read().status.1.clone() } @@ -192,11 +216,13 @@ impl Client { ) -> anyhow::Result<()> { if matches!( *self.status().borrow(), - Status::Connecting { .. } | Status::Connected { .. } + Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. } ) { return Ok(()); } + self.set_status(Status::Authenticating, cx); + let (user_id, access_token) = match self.authenticate(&cx).await { Ok(result) => result, Err(err) => { @@ -205,7 +231,7 @@ impl Client { } }; - self.set_status(Status::Connecting, cx); + self.set_status(Status::Connecting { user_id }, cx); let conn = match self.connect(user_id, &access_token, cx).await { Ok(conn) => conn, @@ -285,11 +311,32 @@ impl Client { Ok(()) } + fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + if let Some(callback) = self.auth_callback.as_ref() { + callback(cx) + } else { + self.authenticate_with_browser(cx) + } + } + fn connect( self: &Arc, user_id: u64, access_token: &str, cx: &AsyncAppContext, + ) -> Task> { + if let Some(callback) = self.connect_callback.as_ref() { + callback(user_id, access_token, cx) + } else { + self.connect_with_websocket(user_id, access_token, cx) + } + } + + fn connect_with_websocket( + self: &Arc, + user_id: u64, + access_token: &str, + cx: &AsyncAppContext, ) -> Task> { let request = Request::builder().header("Authorization", format!("{} {}", user_id, access_token)); @@ -314,7 +361,10 @@ impl Client { }) } - pub fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + pub fn authenticate_with_browser( + self: &Arc, + cx: &AsyncAppContext, + ) -> Task> { let platform = cx.platform(); let executor = cx.background(); executor.clone().spawn(async move { @@ -488,8 +538,8 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_heartbeat(cx: TestAppContext) { let user_id = 5; - let client = Client::new(); - let server = FakeServer::for_client(user_id, &client, &cx).await; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; cx.foreground().advance_clock(Duration::from_secs(10)); let ping = server.receive::().await.unwrap(); diff --git a/zed/src/test.rs b/zed/src/test.rs index e5169ecb69..bee1537b9d 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -203,16 +203,42 @@ pub struct FakeServer { } impl FakeServer { - pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Arc { + pub async fn for_client( + client_user_id: u64, + client: &mut Arc, + cx: &TestAppContext, + ) -> Arc { let result = Arc::new(Self { peer: Peer::new(), incoming: Default::default(), connection_id: Default::default(), }); + Arc::get_mut(client) + .unwrap() + .set_login_and_connect_callbacks( + move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok((client_user_id, access_token)) + }) + }, + { + let server = result.clone(); + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id); + assert_eq!(access_token, "the-token"); + cx.spawn({ + let server = server.clone(); + move |cx| async move { Ok(server.connect(&cx).await) } + }) + } + }, + ); + let conn = result.connect(&cx.to_async()).await; client - .set_connection(user_id, conn, &cx.to_async()) + .set_connection(client_user_id, conn, &cx.to_async()) .await .unwrap(); result