diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index b9b20aa4f2..e6d8f10d12 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -884,27 +884,28 @@ impl Client { let old_credentials = self.state.read().credentials.clone(); if let Some(old_credentials) = old_credentials { - self.cloud_client.set_credentials( - old_credentials.user_id as u32, - old_credentials.access_token.clone(), - ); - - // Fetch the authenticated user with the old credentials, to ensure they are still valid. - if self.cloud_client.get_authenticated_user().await.is_ok() { + if self + .cloud_client + .validate_credentials( + old_credentials.user_id as u32, + &old_credentials.access_token, + ) + .await? + { credentials = Some(old_credentials); } } if credentials.is_none() && try_provider { if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await { - self.cloud_client.set_credentials( - stored_credentials.user_id as u32, - stored_credentials.access_token.clone(), - ); - - // Fetch the authenticated user with the stored credentials, and - // clear them from the credentials provider if that fails. - if self.cloud_client.get_authenticated_user().await.is_ok() { + if self + .cloud_client + .validate_credentials( + stored_credentials.user_id as u32, + &stored_credentials.access_token, + ) + .await? + { credentials = Some(stored_credentials); } else { self.credentials_provider @@ -1709,7 +1710,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> { #[cfg(test)] mod tests { use super::*; - use crate::test::FakeServer; + use crate::test::{FakeServer, parse_authorization_header}; use clock::FakeSystemClock; use gpui::{AppContext as _, BackgroundExecutor, TestAppContext}; @@ -1835,6 +1836,75 @@ mod tests { )); } + #[gpui::test(iterations = 10)] + async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) { + init_test(cx); + let auth_count = Arc::new(Mutex::new(0)); + let http_client = FakeHttpClient::create(|_request| async move { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + }); + let client = + cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx)); + client.override_authenticate({ + let auth_count = auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + cx.background_spawn(async move { + *auth_count.lock() += 1; + Ok(Credentials { + user_id: 1, + access_token: auth_count.lock().to_string(), + }) + }) + } + }); + + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If credentials are still valid, signing in doesn't trigger authentication. + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(credentials.access_token, "1"); + + // If the server is unavailable, signing in doesn't trigger authentication. + http_client + .as_fake() + .replace_handler(|_, _request| async move { + Ok(http_client::Response::builder() + .status(503) + .body("".into()) + .unwrap()) + }); + client.sign_in(false, &cx.to_async()).await.unwrap_err(); + assert_eq!(*auth_count.lock(), 1); + + // If credentials became invalid, signing in triggers authentication. + http_client + .as_fake() + .replace_handler(|_, request| async move { + let credentials = parse_authorization_header(&request).unwrap(); + if credentials.access_token == "2" { + Ok(http_client::Response::builder() + .status(200) + .body("".into()) + .unwrap()) + } else { + Ok(http_client::Response::builder() + .status(401) + .body("".into()) + .unwrap()) + } + }); + let credentials = client.sign_in(false, &cx.to_async()).await.unwrap(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(credentials.access_token, "2"); + } + #[gpui::test(iterations = 10)] async fn test_authenticating_more_than_once( cx: &mut TestAppContext, diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 6689475dae..edac051a0e 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use anyhow::{Result, anyhow}; +use anyhow::{Context, Result, anyhow}; pub use cloud_api_types::*; use futures::AsyncReadExt as _; use http_client::http::request; -use http_client::{AsyncBody, HttpClientWithUrl, Method, Request}; +use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode}; use parking_lot::RwLock; struct Credentials { @@ -40,27 +40,14 @@ impl CloudApiClient { *self.credentials.write() = None; } - fn authorization_header(&self) -> Result { - let guard = self.credentials.read(); - let credentials = guard - .as_ref() - .ok_or_else(|| anyhow!("No credentials provided"))?; - - Ok(format!( - "{} {}", - credentials.user_id, credentials.access_token - )) - } - fn build_request( &self, req: request::Builder, body: impl Into, ) -> Result> { - Ok(req - .header("Content-Type", "application/json") - .header("Authorization", self.authorization_header()?) - .body(body.into())?) + let credentials = self.credentials.read(); + let credentials = credentials.as_ref().context("no credentials provided")?; + build_request(req, body, credentials) } pub async fn get_authenticated_user(&self) -> Result { @@ -152,4 +139,50 @@ impl CloudApiClient { Ok(serde_json::from_str(&body)?) } + + pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { + let request = build_request( + Request::builder().method(Method::GET).uri( + self.http_client + .build_zed_cloud_url("/client/users/me", &[])? + .as_ref(), + ), + AsyncBody::default(), + &Credentials { + user_id, + access_token: access_token.into(), + }, + )?; + + let mut response = self.http_client.send(request).await?; + + if response.status().is_success() { + Ok(true) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if response.status() == StatusCode::UNAUTHORIZED { + return Ok(false); + } else { + return Err(anyhow!( + "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", + response.status() + )); + } + } + } +} + +fn build_request( + req: request::Builder, + body: impl Into, + credentials: &Credentials, +) -> Result> { + Ok(req + .header("Content-Type", "application/json") + .header( + "Authorization", + format!("{} {}", credentials.user_id, credentials.access_token), + ) + .body(body.into())?) }