Clear cached credentials when establishing a websocket connection with an invalid token

This commit is contained in:
Nathan Sobo 2021-09-14 20:36:03 -06:00
parent 4a9918979e
commit 7e4d5b7d04
7 changed files with 63 additions and 37 deletions

View file

@ -48,6 +48,7 @@ pub trait Platform: Send + Sync {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>;
fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>>;
fn delete_credentials(&self, url: &str) -> Result<()>;
fn set_cursor_style(&self, style: CursorStyle);

View file

@ -551,6 +551,25 @@ impl platform::Platform for MacPlatform {
}
}
fn delete_credentials(&self, url: &str) -> Result<()> {
let url = CFString::from(url);
unsafe {
use security::*;
let mut query_attrs = CFMutableDictionary::with_capacity(2);
query_attrs.set(kSecClass as *const _, kSecClassInternetPassword as *const _);
query_attrs.set(kSecAttrServer as *const _, url.as_CFTypeRef());
let status = SecItemDelete(query_attrs.as_concrete_TypeRef());
if status != errSecSuccess {
return Err(anyhow!("delete password failed: {}", status));
}
}
Ok(())
}
fn set_cursor_style(&self, style: CursorStyle) {
unsafe {
let cursor: id = match style {
@ -676,6 +695,7 @@ mod security {
pub fn SecItemAdd(attributes: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
pub fn SecItemUpdate(query: CFDictionaryRef, attributes: CFDictionaryRef) -> OSStatus;
pub fn SecItemDelete(query: CFDictionaryRef) -> OSStatus;
pub fn SecItemCopyMatching(query: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
}

View file

@ -137,6 +137,10 @@ impl super::Platform for Platform {
Ok(None)
}
fn delete_credentials(&self, _: &str) -> Result<()> {
Ok(())
}
fn set_cursor_style(&self, style: CursorStyle) {
*self.cursor.lock() = style;
}

View file

@ -17,7 +17,7 @@ use scrypt::{
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::Url;
use surf::{StatusCode, Url};
use tide::Server;
use zrpc::auth as zed_auth;
@ -73,7 +73,9 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
request.set_ext(user_id);
Ok(next.run(request).await)
} else {
Err(anyhow!("invalid credentials").into())
let mut response = tide::Response::new(StatusCode::Unauthorized);
response.set_body("invalid credentials");
Ok(response)
}
}
}

View file

@ -1,6 +1,9 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request;
use async_tungstenite::tungstenite::{
error::Error as WebsocketError,
http::{Request, StatusCode},
};
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@ -47,10 +50,25 @@ pub struct Client {
#[derive(Error, Debug)]
pub enum EstablishConnectionError {
#[error("invalid access token")]
InvalidAccessToken,
#[error("unauthorized")]
Unauthorized,
#[error("{0}")]
Other(anyhow::Error),
Other(#[from] anyhow::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
Http(#[from] async_tungstenite::tungstenite::http::Error),
}
impl From<WebsocketError> for EstablishConnectionError {
fn from(error: WebsocketError) -> Self {
if let WebsocketError::Http(response) = &error {
if response.status() == StatusCode::UNAUTHORIZED {
return EstablishConnectionError::Unauthorized;
}
}
EstablishConnectionError::Other(error.into())
}
}
impl EstablishConnectionError {
@ -314,10 +332,9 @@ impl Client {
Ok(())
}
Err(err) => {
eprintln!("error in authenticate and connect {}", err);
if matches!(err, EstablishConnectionError::InvalidAccessToken) {
eprintln!("nuking credentials");
if matches!(err, EstablishConnectionError::Unauthorized) {
self.state.write().credentials.take();
cx.platform().delete_credentials(&ZED_SERVER_URL).ok();
}
self.set_status(Status::ConnectionError, cx);
Err(err)?
@ -409,36 +426,18 @@ impl Client {
);
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host)
.await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("wss://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) =
async_tungstenite::async_tls::client_async_tls(request, stream).await?;
Ok(Connection::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host)
.await
.map_err(EstablishConnectionError::other)?;
let request = request
.uri(format!("ws://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
Ok(Connection::new(stream))
} else {
Err(EstablishConnectionError::other(anyhow!(
"invalid server url: {}",
*ZED_SERVER_URL
)))
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
}
})
}

View file

@ -283,7 +283,7 @@ impl FakeServer {
}
if credentials.access_token != self.access_token.load(SeqCst).to_string() {
Err(EstablishConnectionError::InvalidAccessToken)?
Err(EstablishConnectionError::Unauthorized)?
}
let (client_conn, server_conn, _) = Connection::in_memory();

View file

@ -960,7 +960,7 @@ impl Workspace {
fn render_connection_status(&self) -> Option<ElementBox> {
let theme = &self.settings.borrow().theme;
match dbg!(&*self.rpc.status().borrow()) {
match &*self.rpc.status().borrow() {
rpc::Status::ConnectionError
| rpc::Status::ConnectionLost
| rpc::Status::Reauthenticating