From 7e4d5b7d04dba42f3096a2025e719e5881687db4 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 14 Sep 2021 20:36:03 -0600 Subject: [PATCH] Clear cached credentials when establishing a websocket connection with an invalid token --- gpui/src/platform.rs | 1 + gpui/src/platform/mac/platform.rs | 20 ++++++++++ gpui/src/platform/test.rs | 4 ++ server/src/auth.rs | 6 ++- zed/src/rpc.rs | 65 +++++++++++++++---------------- zed/src/test.rs | 2 +- zed/src/workspace.rs | 2 +- 7 files changed, 63 insertions(+), 37 deletions(-) diff --git a/gpui/src/platform.rs b/gpui/src/platform.rs index a4c86eab2f..cd972021a5 100644 --- a/gpui/src/platform.rs +++ b/gpui/src/platform.rs @@ -48,6 +48,7 @@ pub trait Platform: Send + Sync { fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()>; fn read_credentials(&self, url: &str) -> Result)>>; + fn delete_credentials(&self, url: &str) -> Result<()>; fn set_cursor_style(&self, style: CursorStyle); diff --git a/gpui/src/platform/mac/platform.rs b/gpui/src/platform/mac/platform.rs index 7015cbc713..c956a19998 100644 --- a/gpui/src/platform/mac/platform.rs +++ b/gpui/src/platform/mac/platform.rs @@ -551,6 +551,25 @@ impl platform::Platform for MacPlatform { } } + fn delete_credentials(&self, url: &str) -> Result<()> { + let url = CFString::from(url); + + unsafe { + use security::*; + + let mut query_attrs = CFMutableDictionary::with_capacity(2); + query_attrs.set(kSecClass as *const _, kSecClassInternetPassword as *const _); + query_attrs.set(kSecAttrServer as *const _, url.as_CFTypeRef()); + + let status = SecItemDelete(query_attrs.as_concrete_TypeRef()); + + if status != errSecSuccess { + return Err(anyhow!("delete password failed: {}", status)); + } + } + Ok(()) + } + fn set_cursor_style(&self, style: CursorStyle) { unsafe { let cursor: id = match style { @@ -676,6 +695,7 @@ mod security { pub fn SecItemAdd(attributes: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus; pub fn SecItemUpdate(query: CFDictionaryRef, attributes: CFDictionaryRef) -> OSStatus; + pub fn SecItemDelete(query: CFDictionaryRef) -> OSStatus; pub fn SecItemCopyMatching(query: CFDictionaryRef, result: *mut CFTypeRef) -> OSStatus; } diff --git a/gpui/src/platform/test.rs b/gpui/src/platform/test.rs index 85afff4999..d705a277e5 100644 --- a/gpui/src/platform/test.rs +++ b/gpui/src/platform/test.rs @@ -137,6 +137,10 @@ impl super::Platform for Platform { Ok(None) } + fn delete_credentials(&self, _: &str) -> Result<()> { + Ok(()) + } + fn set_cursor_style(&self, style: CursorStyle) { *self.cursor.lock() = style; } diff --git a/server/src/auth.rs b/server/src/auth.rs index 5a3e301d27..1f6ec5f1db 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -17,7 +17,7 @@ use scrypt::{ }; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, convert::TryFrom, sync::Arc}; -use surf::Url; +use surf::{StatusCode, Url}; use tide::Server; use zrpc::auth as zed_auth; @@ -73,7 +73,9 @@ impl tide::Middleware> for VerifyToken { request.set_ext(user_id); Ok(next.run(request).await) } else { - Err(anyhow!("invalid credentials").into()) + let mut response = tide::Response::new(StatusCode::Unauthorized); + response.set_body("invalid credentials"); + Ok(response) } } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 3526381cde..9596b671ed 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -1,6 +1,9 @@ use crate::util::ResultExt; use anyhow::{anyhow, Context, Result}; -use async_tungstenite::tungstenite::http::Request; +use async_tungstenite::tungstenite::{ + error::Error as WebsocketError, + http::{Request, StatusCode}, +}; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -47,10 +50,25 @@ pub struct Client { #[derive(Error, Debug)] pub enum EstablishConnectionError { - #[error("invalid access token")] - InvalidAccessToken, + #[error("unauthorized")] + Unauthorized, #[error("{0}")] - Other(anyhow::Error), + Other(#[from] anyhow::Error), + #[error("{0}")] + Io(#[from] std::io::Error), + #[error("{0}")] + Http(#[from] async_tungstenite::tungstenite::http::Error), +} + +impl From for EstablishConnectionError { + fn from(error: WebsocketError) -> Self { + if let WebsocketError::Http(response) = &error { + if response.status() == StatusCode::UNAUTHORIZED { + return EstablishConnectionError::Unauthorized; + } + } + EstablishConnectionError::Other(error.into()) + } } impl EstablishConnectionError { @@ -314,10 +332,9 @@ impl Client { Ok(()) } Err(err) => { - eprintln!("error in authenticate and connect {}", err); - if matches!(err, EstablishConnectionError::InvalidAccessToken) { - eprintln!("nuking credentials"); + if matches!(err, EstablishConnectionError::Unauthorized) { self.state.write().credentials.take(); + cx.platform().delete_credentials(&ZED_SERVER_URL).ok(); } self.set_status(Status::ConnectionError, cx); Err(err)? @@ -409,36 +426,18 @@ impl Client { ); cx.background().spawn(async move { if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") { - let stream = smol::net::TcpStream::connect(host) - .await - .map_err(EstablishConnectionError::other)?; - let request = request - .uri(format!("wss://{}/rpc", host)) - .body(()) - .map_err(EstablishConnectionError::other)?; - let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream) - .await - .context("websocket handshake") - .map_err(EstablishConnectionError::other)?; + let stream = smol::net::TcpStream::connect(host).await?; + let request = request.uri(format!("wss://{}/rpc", host)).body(())?; + let (stream, _) = + async_tungstenite::async_tls::client_async_tls(request, stream).await?; Ok(Connection::new(stream)) } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") { - let stream = smol::net::TcpStream::connect(host) - .await - .map_err(EstablishConnectionError::other)?; - let request = request - .uri(format!("ws://{}/rpc", host)) - .body(()) - .map_err(EstablishConnectionError::other)?; - let (stream, _) = async_tungstenite::client_async(request, stream) - .await - .context("websocket handshake") - .map_err(EstablishConnectionError::other)?; + let stream = smol::net::TcpStream::connect(host).await?; + let request = request.uri(format!("ws://{}/rpc", host)).body(())?; + let (stream, _) = async_tungstenite::client_async(request, stream).await?; Ok(Connection::new(stream)) } else { - Err(EstablishConnectionError::other(anyhow!( - "invalid server url: {}", - *ZED_SERVER_URL - ))) + Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))? } }) } diff --git a/zed/src/test.rs b/zed/src/test.rs index e5ab3154f5..7d027a8a17 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -283,7 +283,7 @@ impl FakeServer { } if credentials.access_token != self.access_token.load(SeqCst).to_string() { - Err(EstablishConnectionError::InvalidAccessToken)? + Err(EstablishConnectionError::Unauthorized)? } let (client_conn, server_conn, _) = Connection::in_memory(); diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index cbdf3b149a..9ce67c2f8a 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -960,7 +960,7 @@ impl Workspace { fn render_connection_status(&self) -> Option { let theme = &self.settings.borrow().theme; - match dbg!(&*self.rpc.status().borrow()) { + match &*self.rpc.status().borrow() { rpc::Status::ConnectionError | rpc::Status::ConnectionLost | rpc::Status::Reauthenticating