Don't trigger authentication flow unless credentials expired (#35570)
This fixes a regression introduced in https://github.com/zed-industries/zed/pull/35471, where we treated stored credentials as invalid when failing to retrieve the authenticated user for any reason. This had the side effect of triggering the auth flow even when e.g. the client/server had temporary networking issues. This pull request changes the logic to only trigger authentication when getting a 401 from the server. Release Notes: - N/A
This commit is contained in:
parent
5ca5d90234
commit
7217439c97
2 changed files with 137 additions and 34 deletions
|
@ -884,27 +884,28 @@ impl Client {
|
||||||
|
|
||||||
let old_credentials = self.state.read().credentials.clone();
|
let old_credentials = self.state.read().credentials.clone();
|
||||||
if let Some(old_credentials) = old_credentials {
|
if let Some(old_credentials) = old_credentials {
|
||||||
self.cloud_client.set_credentials(
|
if self
|
||||||
old_credentials.user_id as u32,
|
.cloud_client
|
||||||
old_credentials.access_token.clone(),
|
.validate_credentials(
|
||||||
);
|
old_credentials.user_id as u32,
|
||||||
|
&old_credentials.access_token,
|
||||||
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
|
)
|
||||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
.await?
|
||||||
|
{
|
||||||
credentials = Some(old_credentials);
|
credentials = Some(old_credentials);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if credentials.is_none() && try_provider {
|
if credentials.is_none() && try_provider {
|
||||||
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
||||||
self.cloud_client.set_credentials(
|
if self
|
||||||
stored_credentials.user_id as u32,
|
.cloud_client
|
||||||
stored_credentials.access_token.clone(),
|
.validate_credentials(
|
||||||
);
|
stored_credentials.user_id as u32,
|
||||||
|
&stored_credentials.access_token,
|
||||||
// Fetch the authenticated user with the stored credentials, and
|
)
|
||||||
// clear them from the credentials provider if that fails.
|
.await?
|
||||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
{
|
||||||
credentials = Some(stored_credentials);
|
credentials = Some(stored_credentials);
|
||||||
} else {
|
} else {
|
||||||
self.credentials_provider
|
self.credentials_provider
|
||||||
|
@ -1709,7 +1710,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::test::FakeServer;
|
use crate::test::{FakeServer, parse_authorization_header};
|
||||||
|
|
||||||
use clock::FakeSystemClock;
|
use clock::FakeSystemClock;
|
||||||
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
|
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
|
||||||
|
@ -1835,6 +1836,75 @@ mod tests {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
|
||||||
|
init_test(cx);
|
||||||
|
let auth_count = Arc::new(Mutex::new(0));
|
||||||
|
let http_client = FakeHttpClient::create(|_request| async move {
|
||||||
|
Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body("".into())
|
||||||
|
.unwrap())
|
||||||
|
});
|
||||||
|
let client =
|
||||||
|
cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
|
||||||
|
client.override_authenticate({
|
||||||
|
let auth_count = auth_count.clone();
|
||||||
|
move |cx| {
|
||||||
|
let auth_count = auth_count.clone();
|
||||||
|
cx.background_spawn(async move {
|
||||||
|
*auth_count.lock() += 1;
|
||||||
|
Ok(Credentials {
|
||||||
|
user_id: 1,
|
||||||
|
access_token: auth_count.lock().to_string(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
|
||||||
|
assert_eq!(*auth_count.lock(), 1);
|
||||||
|
assert_eq!(credentials.access_token, "1");
|
||||||
|
|
||||||
|
// If credentials are still valid, signing in doesn't trigger authentication.
|
||||||
|
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
|
||||||
|
assert_eq!(*auth_count.lock(), 1);
|
||||||
|
assert_eq!(credentials.access_token, "1");
|
||||||
|
|
||||||
|
// If the server is unavailable, signing in doesn't trigger authentication.
|
||||||
|
http_client
|
||||||
|
.as_fake()
|
||||||
|
.replace_handler(|_, _request| async move {
|
||||||
|
Ok(http_client::Response::builder()
|
||||||
|
.status(503)
|
||||||
|
.body("".into())
|
||||||
|
.unwrap())
|
||||||
|
});
|
||||||
|
client.sign_in(false, &cx.to_async()).await.unwrap_err();
|
||||||
|
assert_eq!(*auth_count.lock(), 1);
|
||||||
|
|
||||||
|
// If credentials became invalid, signing in triggers authentication.
|
||||||
|
http_client
|
||||||
|
.as_fake()
|
||||||
|
.replace_handler(|_, request| async move {
|
||||||
|
let credentials = parse_authorization_header(&request).unwrap();
|
||||||
|
if credentials.access_token == "2" {
|
||||||
|
Ok(http_client::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.body("".into())
|
||||||
|
.unwrap())
|
||||||
|
} else {
|
||||||
|
Ok(http_client::Response::builder()
|
||||||
|
.status(401)
|
||||||
|
.body("".into())
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
|
||||||
|
assert_eq!(*auth_count.lock(), 2);
|
||||||
|
assert_eq!(credentials.access_token, "2");
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_authenticating_more_than_once(
|
async fn test_authenticating_more_than_once(
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Context, Result, anyhow};
|
||||||
pub use cloud_api_types::*;
|
pub use cloud_api_types::*;
|
||||||
use futures::AsyncReadExt as _;
|
use futures::AsyncReadExt as _;
|
||||||
use http_client::http::request;
|
use http_client::http::request;
|
||||||
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
|
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode};
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
|
|
||||||
struct Credentials {
|
struct Credentials {
|
||||||
|
@ -40,27 +40,14 @@ impl CloudApiClient {
|
||||||
*self.credentials.write() = None;
|
*self.credentials.write() = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authorization_header(&self) -> Result<String> {
|
|
||||||
let guard = self.credentials.read();
|
|
||||||
let credentials = guard
|
|
||||||
.as_ref()
|
|
||||||
.ok_or_else(|| anyhow!("No credentials provided"))?;
|
|
||||||
|
|
||||||
Ok(format!(
|
|
||||||
"{} {}",
|
|
||||||
credentials.user_id, credentials.access_token
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_request(
|
fn build_request(
|
||||||
&self,
|
&self,
|
||||||
req: request::Builder,
|
req: request::Builder,
|
||||||
body: impl Into<AsyncBody>,
|
body: impl Into<AsyncBody>,
|
||||||
) -> Result<Request<AsyncBody>> {
|
) -> Result<Request<AsyncBody>> {
|
||||||
Ok(req
|
let credentials = self.credentials.read();
|
||||||
.header("Content-Type", "application/json")
|
let credentials = credentials.as_ref().context("no credentials provided")?;
|
||||||
.header("Authorization", self.authorization_header()?)
|
build_request(req, body, credentials)
|
||||||
.body(body.into())?)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
|
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
|
||||||
|
@ -152,4 +139,50 @@ impl CloudApiClient {
|
||||||
|
|
||||||
Ok(serde_json::from_str(&body)?)
|
Ok(serde_json::from_str(&body)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result<bool> {
|
||||||
|
let request = build_request(
|
||||||
|
Request::builder().method(Method::GET).uri(
|
||||||
|
self.http_client
|
||||||
|
.build_zed_cloud_url("/client/users/me", &[])?
|
||||||
|
.as_ref(),
|
||||||
|
),
|
||||||
|
AsyncBody::default(),
|
||||||
|
&Credentials {
|
||||||
|
user_id,
|
||||||
|
access_token: access_token.into(),
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut response = self.http_client.send(request).await?;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
Ok(true)
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
if response.status() == StatusCode::UNAUTHORIZED {
|
||||||
|
return Ok(false);
|
||||||
|
} else {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
|
||||||
|
response.status()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_request(
|
||||||
|
req: request::Builder,
|
||||||
|
body: impl Into<AsyncBody>,
|
||||||
|
credentials: &Credentials,
|
||||||
|
) -> Result<Request<AsyncBody>> {
|
||||||
|
Ok(req
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header(
|
||||||
|
"Authorization",
|
||||||
|
format!("{} {}", credentials.user_id, credentials.access_token),
|
||||||
|
)
|
||||||
|
.body(body.into())?)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue