Clear cached credentials when establishing a websocket connection with an invalid token

This commit is contained in:
Nathan Sobo 2021-09-14 20:36:03 -06:00
parent 4a9918979e
commit 7e4d5b7d04
7 changed files with 63 additions and 37 deletions

View file

@ -48,6 +48,7 @@ pub trait Platform: Send + Sync {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>; fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>;
fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>>; fn read_credentials(&self, url: &str) -> Result<Option<(String, Vec<u8>)>>;
fn delete_credentials(&self, url: &str) -> Result<()>;
fn set_cursor_style(&self, style: CursorStyle); fn set_cursor_style(&self, style: CursorStyle);

View file

@ -551,6 +551,25 @@ impl platform::Platform for MacPlatform {
} }
} }
fn delete_credentials(&self, url: &str) -> Result<()> {
let url = CFString::from(url);
unsafe {
use security::*;
let mut query_attrs = CFMutableDictionary::with_capacity(2);
query_attrs.set(kSecClass as *const _, kSecClassInternetPassword as *const _);
query_attrs.set(kSecAttrServer as *const _, url.as_CFTypeRef());
let status = SecItemDelete(query_attrs.as_concrete_TypeRef());
if status != errSecSuccess {
return Err(anyhow!("delete password failed: {}", status));
}
}
Ok(())
}
fn set_cursor_style(&self, style: CursorStyle) { fn set_cursor_style(&self, style: CursorStyle) {
unsafe { unsafe {
let cursor: id = match style { let cursor: id = match style {
@ -676,6 +695,7 @@ mod security {
pub fn SecItemAdd(attributes: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus; pub fn SecItemAdd(attributes: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
pub fn SecItemUpdate(query: CFDictionaryRef, attributes: CFDictionaryRef) -> OSStatus; pub fn SecItemUpdate(query: CFDictionaryRef, attributes: CFDictionaryRef) -> OSStatus;
pub fn SecItemDelete(query: CFDictionaryRef) -> OSStatus;
pub fn SecItemCopyMatching(query: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus; pub fn SecItemCopyMatching(query: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus;
} }

View file

@ -137,6 +137,10 @@ impl super::Platform for Platform {
Ok(None) Ok(None)
} }
fn delete_credentials(&self, _: &str) -> Result<()> {
Ok(())
}
fn set_cursor_style(&self, style: CursorStyle) { fn set_cursor_style(&self, style: CursorStyle) {
*self.cursor.lock() = style; *self.cursor.lock() = style;
} }

View file

@ -17,7 +17,7 @@ use scrypt::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{borrow::Cow, convert::TryFrom, sync::Arc}; use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::Url; use surf::{StatusCode, Url};
use tide::Server; use tide::Server;
use zrpc::auth as zed_auth; use zrpc::auth as zed_auth;
@ -73,7 +73,9 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
request.set_ext(user_id); request.set_ext(user_id);
Ok(next.run(request).await) Ok(next.run(request).await)
} else { } else {
Err(anyhow!("invalid credentials").into()) let mut response = tide::Response::new(StatusCode::Unauthorized);
response.set_body("invalid credentials");
Ok(response)
} }
} }
} }

View file

@ -1,6 +1,9 @@
use crate::util::ResultExt; use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request; use async_tungstenite::tungstenite::{
error::Error as WebsocketError,
http::{Request, StatusCode},
};
use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use parking_lot::RwLock; use parking_lot::RwLock;
@ -47,10 +50,25 @@ pub struct Client {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum EstablishConnectionError { pub enum EstablishConnectionError {
#[error("invalid access token")] #[error("unauthorized")]
InvalidAccessToken, Unauthorized,
#[error("{0}")] #[error("{0}")]
Other(anyhow::Error), Other(#[from] anyhow::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
Http(#[from] async_tungstenite::tungstenite::http::Error),
}
impl From<WebsocketError> for EstablishConnectionError {
fn from(error: WebsocketError) -> Self {
if let WebsocketError::Http(response) = &error {
if response.status() == StatusCode::UNAUTHORIZED {
return EstablishConnectionError::Unauthorized;
}
}
EstablishConnectionError::Other(error.into())
}
} }
impl EstablishConnectionError { impl EstablishConnectionError {
@ -314,10 +332,9 @@ impl Client {
Ok(()) Ok(())
} }
Err(err) => { Err(err) => {
eprintln!("error in authenticate and connect {}", err); if matches!(err, EstablishConnectionError::Unauthorized) {
if matches!(err, EstablishConnectionError::InvalidAccessToken) {
eprintln!("nuking credentials");
self.state.write().credentials.take(); self.state.write().credentials.take();
cx.platform().delete_credentials(&ZED_SERVER_URL).ok();
} }
self.set_status(Status::ConnectionError, cx); self.set_status(Status::ConnectionError, cx);
Err(err)? Err(err)?
@ -409,36 +426,18 @@ impl Client {
); );
cx.background().spawn(async move { cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host) let stream = smol::net::TcpStream::connect(host).await?;
.await let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
.map_err(EstablishConnectionError::other)?; let (stream, _) =
let request = request async_tungstenite::async_tls::client_async_tls(request, stream).await?;
.uri(format!("wss://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream)) Ok(Connection::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host) let stream = smol::net::TcpStream::connect(host).await?;
.await let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
.map_err(EstablishConnectionError::other)?; let (stream, _) = async_tungstenite::client_async(request, stream).await?;
let request = request
.uri(format!("ws://{}/rpc", host))
.body(())
.map_err(EstablishConnectionError::other)?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")
.map_err(EstablishConnectionError::other)?;
Ok(Connection::new(stream)) Ok(Connection::new(stream))
} else { } else {
Err(EstablishConnectionError::other(anyhow!( Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
"invalid server url: {}",
*ZED_SERVER_URL
)))
} }
}) })
} }

View file

@ -283,7 +283,7 @@ impl FakeServer {
} }
if credentials.access_token != self.access_token.load(SeqCst).to_string() { if credentials.access_token != self.access_token.load(SeqCst).to_string() {
Err(EstablishConnectionError::InvalidAccessToken)? Err(EstablishConnectionError::Unauthorized)?
} }
let (client_conn, server_conn, _) = Connection::in_memory(); let (client_conn, server_conn, _) = Connection::in_memory();

View file

@ -960,7 +960,7 @@ impl Workspace {
fn render_connection_status(&self) -> Option<ElementBox> { fn render_connection_status(&self) -> Option<ElementBox> {
let theme = &self.settings.borrow().theme; let theme = &self.settings.borrow().theme;
match dbg!(&*self.rpc.status().borrow()) { match &*self.rpc.status().borrow() {
rpc::Status::ConnectionError rpc::Status::ConnectionError
| rpc::Status::ConnectionLost | rpc::Status::ConnectionLost
| rpc::Status::Reauthenticating | rpc::Status::Reauthenticating