Start separating authentication from connection to collab (#35471)
This pull request should be idempotent, but lays the groundwork for avoiding to connect to collab in order to interact with AI features provided by Zed. Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com> Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
b01d1872cc
commit
f888f3fc0b
46 changed files with 653 additions and 855 deletions
|
@ -1,14 +1,12 @@
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod cloud;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use async_recursion::async_recursion;
|
||||
use async_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::Error as WebsocketError,
|
||||
|
@ -52,7 +50,6 @@ use tokio::net::TcpStream;
|
|||
use url::Url;
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use cloud::*;
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
|
@ -164,20 +161,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
|||
let client = client.clone();
|
||||
move |_: &SignIn, cx| {
|
||||
if let Some(client) = client.upgrade() {
|
||||
cx.spawn(
|
||||
async move |cx| match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("Initial authentication timed out");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("Initial authentication connection reset");
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -286,6 +271,8 @@ pub enum Status {
|
|||
SignedOut,
|
||||
UpgradeRequired,
|
||||
Authenticating,
|
||||
Authenticated,
|
||||
AuthenticationError,
|
||||
Connecting,
|
||||
ConnectionError,
|
||||
Connected {
|
||||
|
@ -712,7 +699,7 @@ impl Client {
|
|||
|
||||
let mut delay = INITIAL_RECONNECTION_DELAY;
|
||||
loop {
|
||||
match client.authenticate_and_connect(true, &cx).await {
|
||||
match client.connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("client connect attempt timed out")
|
||||
}
|
||||
|
@ -882,17 +869,122 @@ impl Client {
|
|||
.is_some()
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
pub async fn authenticate_and_connect(
|
||||
pub async fn sign_in(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Credentials> {
|
||||
if self.status().borrow().is_signed_out() {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx);
|
||||
}
|
||||
|
||||
let mut credentials = None;
|
||||
|
||||
let old_credentials = self.state.read().credentials.clone();
|
||||
if let Some(old_credentials) = old_credentials {
|
||||
self.cloud_client.set_credentials(
|
||||
old_credentials.user_id as u32,
|
||||
old_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(old_credentials);
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() && try_provider {
|
||||
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
||||
self.cloud_client.set_credentials(
|
||||
stored_credentials.user_id as u32,
|
||||
stored_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the stored credentials, and
|
||||
// clear them from the credentials provider if that fails.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(stored_credentials);
|
||||
} else {
|
||||
self.credentials_provider
|
||||
.delete_credentials(cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => {
|
||||
if IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider
|
||||
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
credentials = Some(creds);
|
||||
},
|
||||
Err(err) => {
|
||||
self.set_status(Status::AuthenticationError, cx);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return Err(anyhow!("authentication canceled"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
self.set_status(Status::Authenticated, cx);
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Performs a sign-in and also connects to Collab.
|
||||
///
|
||||
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
|
||||
/// to `sign_in` when we're ready to remove auto-connection to Collab.
|
||||
pub async fn sign_in_with_optional_connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let credentials = self.sign_in(try_provider, cx).await?;
|
||||
|
||||
let connect_result = match self.connect_with_credentials(credentials, cx).await {
|
||||
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
|
||||
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
|
||||
ConnectionResult::Result(result) => result.context("client auth and connect"),
|
||||
};
|
||||
connect_result.log_err();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::SignedOut => true,
|
||||
Status::SignedOut | Status::Authenticated => true,
|
||||
Status::ConnectionError
|
||||
| Status::ConnectionLost
|
||||
| Status::Authenticating { .. }
|
||||
| Status::AuthenticationError
|
||||
| Status::Reauthenticating { .. }
|
||||
| Status::ReconnectionError { .. } => false,
|
||||
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
||||
|
@ -905,41 +997,10 @@ impl Client {
|
|||
);
|
||||
}
|
||||
};
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx)
|
||||
}
|
||||
|
||||
let mut read_from_provider = false;
|
||||
let mut credentials = self.state.read().credentials.clone();
|
||||
if credentials.is_none() && try_provider {
|
||||
credentials = self.credentials_provider.read_credentials(cx).await;
|
||||
read_from_provider = credentials.is_some();
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => credentials = Some(creds),
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return ConnectionResult::Result(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
|
||||
}
|
||||
}
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
let credentials = match self.sign_in(try_provider, cx).await {
|
||||
Ok(credentials) => credentials,
|
||||
Err(err) => return ConnectionResult::Result(Err(err)),
|
||||
};
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
|
@ -947,17 +1008,20 @@ impl Client {
|
|||
self.set_status(Status::Reconnecting, cx);
|
||||
}
|
||||
|
||||
self.connect_with_credentials(credentials, cx).await
|
||||
}
|
||||
|
||||
async fn connect_with_credentials(
|
||||
self: &Arc<Self>,
|
||||
credentials: Credentials,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let mut timeout =
|
||||
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
|
||||
futures::select_biased! {
|
||||
connection = self.establish_connection(&credentials, cx).fuse() => {
|
||||
match connection {
|
||||
Ok(conn) => {
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
result = self.set_connection(conn, cx).fuse() => {
|
||||
match result.context("client auth and connect") {
|
||||
|
@ -975,15 +1039,8 @@ impl Client {
|
|||
}
|
||||
}
|
||||
Err(EstablishConnectionError::Unauthorized) => {
|
||||
self.state.write().credentials.take();
|
||||
if read_from_provider {
|
||||
self.credentials_provider.delete_credentials(cx).await.log_err();
|
||||
self.set_status(Status::SignedOut, cx);
|
||||
self.authenticate_and_connect(false, cx).await
|
||||
} else {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
Err(EstablishConnectionError::UpgradeRequired) => {
|
||||
self.set_status(Status::UpgradeRequired, cx);
|
||||
|
@ -1733,7 +1790,7 @@ mod tests {
|
|||
});
|
||||
let auth_and_connect = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert!(matches!(status.next().await, Some(Status::Connecting)));
|
||||
|
@ -1810,7 +1867,7 @@ mod tests {
|
|||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
move |cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 1);
|
||||
|
@ -1818,7 +1875,7 @@ mod tests {
|
|||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 2);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue