Refactor add_request_handler to respond via a Response struct

This also removes `add_sync_request_handler`.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2022-05-06 17:01:27 +02:00
parent 9555b93bca
commit 989b82d664

View file

@ -18,7 +18,7 @@ use axum::{
headers::{Header, HeaderName}, headers::{Header, HeaderName},
http::StatusCode, http::StatusCode,
middleware, middleware,
response::{IntoResponse, Response}, response::IntoResponse,
routing::get, routing::get,
Extension, Router, TypedHeader, Extension, Router, TypedHeader,
}; };
@ -27,7 +27,7 @@ use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt, T
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rpc::{ use rpc::{
proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}, proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Connection, ConnectionId, Peer, TypedEnvelope, Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
}; };
use std::{ use std::{
any::TypeId, any::TypeId,
@ -36,7 +36,10 @@ use std::{
net::SocketAddr, net::SocketAddr,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
rc::Rc, rc::Rc,
sync::Arc, sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
time::Duration, time::Duration,
}; };
use store::{Store, Worktree}; use store::{Store, Worktree};
@ -51,6 +54,20 @@ use tracing::{info_span, instrument, Instrument};
type MessageHandler = type MessageHandler =
Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>; Box<dyn Send + Sync + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, ()>>;
struct Response<R> {
server: Arc<Server>,
receipt: Receipt<R>,
responded: Arc<AtomicBool>,
}
impl<R: RequestMessage> Response<R> {
fn send(self, payload: R::Response) -> Result<()> {
self.responded.store(true, SeqCst);
self.server.peer.respond(self.receipt, payload)?;
Ok(())
}
}
pub struct Server { pub struct Server {
peer: Arc<Peer>, peer: Arc<Peer>,
store: RwLock<Store>, store: RwLock<Store>,
@ -100,7 +117,7 @@ impl Server {
.add_message_handler(Server::unregister_project) .add_message_handler(Server::unregister_project)
.add_request_handler(Server::share_project) .add_request_handler(Server::share_project)
.add_message_handler(Server::unshare_project) .add_message_handler(Server::unshare_project)
.add_sync_request_handler(Server::join_project) .add_request_handler(Server::join_project)
.add_message_handler(Server::leave_project) .add_message_handler(Server::leave_project)
.add_request_handler(Server::register_worktree) .add_request_handler(Server::register_worktree)
.add_message_handler(Server::unregister_worktree) .add_message_handler(Server::unregister_worktree)
@ -179,43 +196,12 @@ impl Server {
self self
} }
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
Fut: 'static + Send + Future<Output = Result<M::Response>>,
M: RequestMessage,
{
self.add_message_handler(move |server, envelope| {
let receipt = envelope.receipt();
let response = (handler)(server.clone(), envelope);
async move {
match response.await {
Ok(response) => {
server.peer.respond(receipt, response)?;
Ok(())
}
Err(error) => {
server.peer.respond_with_error(
receipt,
proto::Error {
message: error.to_string(),
},
)?;
Err(error)
}
}
}
})
}
/// Handle a request while holding a lock to the store. This is useful when we're registering /// Handle a request while holding a lock to the store. This is useful when we're registering
/// a connection but we want to respond on the connection before anybody else can send on it. /// a connection but we want to respond on the connection before anybody else can send on it.
fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where where
F: 'static F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>, Response<M>) -> Fut,
+ Send Fut: Send + Future<Output = Result<()>>,
+ Sync
+ Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> Result<M::Response>,
M: RequestMessage, M: RequestMessage,
{ {
let handler = Arc::new(handler); let handler = Arc::new(handler);
@ -223,12 +209,19 @@ impl Server {
let receipt = envelope.receipt(); let receipt = envelope.receipt();
let handler = handler.clone(); let handler = handler.clone();
async move { async move {
let mut store = server.state_mut().await; let responded = Arc::new(AtomicBool::default());
let response = (handler)(server.clone(), &mut *store, envelope); let response = Response {
match response { server: server.clone(),
Ok(response) => { responded: responded.clone(),
server.peer.respond(receipt, response)?; receipt: envelope.receipt(),
Ok(()) };
match (handler)(server.clone(), envelope, response).await {
Ok(()) => {
if responded.load(std::sync::atomic::Ordering::SeqCst) {
Ok(())
} else {
Err(anyhow!("handler did not send a response"))?
}
} }
Err(error) => { Err(error) => {
server.peer.respond_with_error( server.peer.respond_with_error(
@ -364,20 +357,27 @@ impl Server {
Ok(()) Ok(())
} }
async fn ping(self: Arc<Server>, _: TypedEnvelope<proto::Ping>) -> Result<proto::Ack> { async fn ping(
Ok(proto::Ack {}) self: Arc<Server>,
_: TypedEnvelope<proto::Ping>,
response: Response<proto::Ping>,
) -> Result<()> {
response.send(proto::Ack {})?;
Ok(())
} }
async fn register_project( async fn register_project(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::RegisterProject>, request: TypedEnvelope<proto::RegisterProject>,
) -> Result<proto::RegisterProjectResponse> { response: Response<proto::RegisterProject>,
) -> Result<()> {
let project_id = { let project_id = {
let mut state = self.state_mut().await; let mut state = self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?; let user_id = state.user_id_for_connection(request.sender_id)?;
state.register_project(request.sender_id, user_id) state.register_project(request.sender_id, user_id)
}; };
Ok(proto::RegisterProjectResponse { project_id }) response.send(proto::RegisterProjectResponse { project_id })?;
Ok(())
} }
async fn unregister_project( async fn unregister_project(
@ -393,11 +393,13 @@ impl Server {
async fn share_project( async fn share_project(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::ShareProject>, request: TypedEnvelope<proto::ShareProject>,
) -> Result<proto::Ack> { response: Response<proto::ShareProject>,
) -> Result<()> {
let mut state = self.state_mut().await; let mut state = self.state_mut().await;
let project = state.share_project(request.payload.project_id, request.sender_id)?; let project = state.share_project(request.payload.project_id, request.sender_id)?;
self.update_contacts_for_users(&mut *state, &project.authorized_user_ids); self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
Ok(proto::Ack {}) response.send(proto::Ack {})?;
Ok(())
} }
async fn unshare_project( async fn unshare_project(
@ -415,15 +417,16 @@ impl Server {
Ok(()) Ok(())
} }
fn join_project( async fn join_project(
self: Arc<Server>, self: Arc<Server>,
state: &mut Store,
request: TypedEnvelope<proto::JoinProject>, request: TypedEnvelope<proto::JoinProject>,
) -> Result<proto::JoinProjectResponse> { response: Response<proto::JoinProject>,
) -> Result<()> {
let project_id = request.payload.project_id; let project_id = request.payload.project_id;
let state = &mut *self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?; let user_id = state.user_id_for_connection(request.sender_id)?;
let (response, connection_ids, contact_user_ids) = state let (response_payload, connection_ids, contact_user_ids) = state
.join_project(request.sender_id, user_id, project_id) .join_project(request.sender_id, user_id, project_id)
.and_then(|joined| { .and_then(|joined| {
let share = joined.project.share()?; let share = joined.project.share()?;
@ -480,14 +483,15 @@ impl Server {
project_id, project_id,
collaborator: Some(proto::Collaborator { collaborator: Some(proto::Collaborator {
peer_id: request.sender_id.0, peer_id: request.sender_id.0,
replica_id: response.replica_id, replica_id: response_payload.replica_id,
user_id: user_id.to_proto(), user_id: user_id.to_proto(),
}), }),
}, },
) )
}); });
self.update_contacts_for_users(state, &contact_user_ids); self.update_contacts_for_users(state, &contact_user_ids);
Ok(response) response.send(response_payload)?;
Ok(())
} }
async fn leave_project( async fn leave_project(
@ -514,7 +518,8 @@ impl Server {
async fn register_worktree( async fn register_worktree(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::RegisterWorktree>, request: TypedEnvelope<proto::RegisterWorktree>,
) -> Result<proto::Ack> { response: Response<proto::RegisterWorktree>,
) -> Result<()> {
let mut contact_user_ids = HashSet::default(); let mut contact_user_ids = HashSet::default();
for github_login in &request.payload.authorized_logins { for github_login in &request.payload.authorized_logins {
let contact_user_id = self.app_state.db.create_user(github_login, false).await?; let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
@ -545,7 +550,8 @@ impl Server {
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}); });
self.update_contacts_for_users(&*state, &contact_user_ids); self.update_contacts_for_users(&*state, &contact_user_ids);
Ok(proto::Ack {}) response.send(proto::Ack {})?;
Ok(())
} }
async fn unregister_worktree( async fn unregister_worktree(
@ -573,7 +579,8 @@ impl Server {
async fn update_worktree( async fn update_worktree(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>, request: TypedEnvelope<proto::UpdateWorktree>,
) -> Result<proto::Ack> { response: Response<proto::UpdateWorktree>,
) -> Result<()> {
let connection_ids = self.state_mut().await.update_worktree( let connection_ids = self.state_mut().await.update_worktree(
request.sender_id, request.sender_id,
request.payload.project_id, request.payload.project_id,
@ -587,8 +594,8 @@ impl Server {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}); });
response.send(proto::Ack {})?;
Ok(proto::Ack {}) Ok(())
} }
async fn update_diagnostic_summary( async fn update_diagnostic_summary(
@ -652,7 +659,8 @@ impl Server {
async fn forward_project_request<T>( async fn forward_project_request<T>(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<T>, request: TypedEnvelope<T>,
) -> Result<T::Response> response: Response<T>,
) -> Result<()>
where where
T: EntityMessage + RequestMessage, T: EntityMessage + RequestMessage,
{ {
@ -661,22 +669,26 @@ impl Server {
.await .await
.read_project(request.payload.remote_entity_id(), request.sender_id)? .read_project(request.payload.remote_entity_id(), request.sender_id)?
.host_connection_id; .host_connection_id;
Ok(self
.peer response.send(
.forward_request(request.sender_id, host_connection_id, request.payload) self.peer
.await?) .forward_request(request.sender_id, host_connection_id, request.payload)
.await?,
)?;
Ok(())
} }
async fn save_buffer( async fn save_buffer(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::SaveBuffer>, request: TypedEnvelope<proto::SaveBuffer>,
) -> Result<proto::BufferSaved> { response: Response<proto::SaveBuffer>,
) -> Result<()> {
let host = self let host = self
.state() .state()
.await .await
.read_project(request.payload.project_id, request.sender_id)? .read_project(request.payload.project_id, request.sender_id)?
.host_connection_id; .host_connection_id;
let response = self let response_payload = self
.peer .peer
.forward_request(request.sender_id, host, request.payload.clone()) .forward_request(request.sender_id, host, request.payload.clone())
.await?; .await?;
@ -688,16 +700,18 @@ impl Server {
.connection_ids(); .connection_ids();
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
broadcast(host, guests, |conn_id| { broadcast(host, guests, |conn_id| {
self.peer.forward_send(host, conn_id, response.clone()) self.peer
.forward_send(host, conn_id, response_payload.clone())
}); });
response.send(response_payload)?;
Ok(response) Ok(())
} }
async fn update_buffer( async fn update_buffer(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::UpdateBuffer>, request: TypedEnvelope<proto::UpdateBuffer>,
) -> Result<proto::Ack> { response: Response<proto::UpdateBuffer>,
) -> Result<()> {
let receiver_ids = self let receiver_ids = self
.state() .state()
.await .await
@ -706,7 +720,8 @@ impl Server {
self.peer self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone()) .forward_send(request.sender_id, connection_id, request.payload.clone())
}); });
Ok(proto::Ack {}) response.send(proto::Ack {})?;
Ok(())
} }
async fn update_buffer_file( async fn update_buffer_file(
@ -757,7 +772,8 @@ impl Server {
async fn follow( async fn follow(
self: Arc<Self>, self: Arc<Self>,
request: TypedEnvelope<proto::Follow>, request: TypedEnvelope<proto::Follow>,
) -> Result<proto::FollowResponse> { response: Response<proto::Follow>,
) -> Result<()> {
let leader_id = ConnectionId(request.payload.leader_id); let leader_id = ConnectionId(request.payload.leader_id);
let follower_id = request.sender_id; let follower_id = request.sender_id;
if !self if !self
@ -768,14 +784,15 @@ impl Server {
{ {
Err(anyhow!("no such peer"))?; Err(anyhow!("no such peer"))?;
} }
let mut response = self let mut response_payload = self
.peer .peer
.forward_request(request.sender_id, leader_id, request.payload) .forward_request(request.sender_id, leader_id, request.payload)
.await?; .await?;
response response_payload
.views .views
.retain(|view| view.leader_id != Some(follower_id.0)); .retain(|view| view.leader_id != Some(follower_id.0));
Ok(response) response.send(response_payload)?;
Ok(())
} }
async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> { async fn unfollow(self: Arc<Self>, request: TypedEnvelope<proto::Unfollow>) -> Result<()> {
@ -823,13 +840,14 @@ impl Server {
async fn get_channels( async fn get_channels(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::GetChannels>, request: TypedEnvelope<proto::GetChannels>,
) -> Result<proto::GetChannelsResponse> { response: Response<proto::GetChannels>,
) -> Result<()> {
let user_id = self let user_id = self
.state() .state()
.await .await
.user_id_for_connection(request.sender_id)?; .user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?; let channels = self.app_state.db.get_accessible_channels(user_id).await?;
Ok(proto::GetChannelsResponse { response.send(proto::GetChannelsResponse {
channels: channels channels: channels
.into_iter() .into_iter()
.map(|chan| proto::Channel { .map(|chan| proto::Channel {
@ -837,13 +855,15 @@ impl Server {
name: chan.name, name: chan.name,
}) })
.collect(), .collect(),
}) })?;
Ok(())
} }
async fn get_users( async fn get_users(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>, request: TypedEnvelope<proto::GetUsers>,
) -> Result<proto::UsersResponse> { response: Response<proto::GetUsers>,
) -> Result<()> {
let user_ids = request let user_ids = request
.payload .payload
.user_ids .user_ids
@ -862,13 +882,15 @@ impl Server {
github_login: user.github_login, github_login: user.github_login,
}) })
.collect(); .collect();
Ok(proto::UsersResponse { users }) response.send(proto::UsersResponse { users })?;
Ok(())
} }
async fn fuzzy_search_users( async fn fuzzy_search_users(
self: Arc<Server>, self: Arc<Server>,
request: TypedEnvelope<proto::FuzzySearchUsers>, request: TypedEnvelope<proto::FuzzySearchUsers>,
) -> Result<proto::UsersResponse> { response: Response<proto::FuzzySearchUsers>,
) -> Result<()> {
let query = request.payload.query; let query = request.payload.query;
let db = &self.app_state.db; let db = &self.app_state.db;
let users = match query.len() { let users = match query.len() {
@ -888,7 +910,8 @@ impl Server {
github_login: user.github_login, github_login: user.github_login,
}) })
.collect(); .collect();
Ok(proto::UsersResponse { users }) response.send(proto::UsersResponse { users })?;
Ok(())
} }
#[instrument(skip(self, state, user_ids))] #[instrument(skip(self, state, user_ids))]
@ -917,7 +940,8 @@ impl Server {
async fn join_channel( async fn join_channel(
self: Arc<Self>, self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>, request: TypedEnvelope<proto::JoinChannel>,
) -> Result<proto::JoinChannelResponse> { response: Response<proto::JoinChannel>,
) -> Result<()> {
let user_id = self let user_id = self
.state() .state()
.await .await
@ -949,10 +973,11 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()), nonce: Some(msg.nonce.as_u128().into()),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(proto::JoinChannelResponse { response.send(proto::JoinChannelResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE, done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages, messages,
}) })?;
Ok(())
} }
async fn leave_channel( async fn leave_channel(
@ -983,7 +1008,8 @@ impl Server {
async fn send_channel_message( async fn send_channel_message(
self: Arc<Self>, self: Arc<Self>,
request: TypedEnvelope<proto::SendChannelMessage>, request: TypedEnvelope<proto::SendChannelMessage>,
) -> Result<proto::SendChannelMessageResponse> { response: Response<proto::SendChannelMessage>,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id); let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id; let user_id;
let connection_ids; let connection_ids;
@ -1030,15 +1056,17 @@ impl Server {
}, },
) )
}); });
Ok(proto::SendChannelMessageResponse { response.send(proto::SendChannelMessageResponse {
message: Some(message), message: Some(message),
}) })?;
Ok(())
} }
async fn get_channel_messages( async fn get_channel_messages(
self: Arc<Self>, self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>, request: TypedEnvelope<proto::GetChannelMessages>,
) -> Result<proto::GetChannelMessagesResponse> { response: Response<proto::GetChannelMessages>,
) -> Result<()> {
let user_id = self let user_id = self
.state() .state()
.await .await
@ -1071,11 +1099,11 @@ impl Server {
nonce: Some(msg.nonce.as_u128().into()), nonce: Some(msg.nonce.as_u128().into()),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
response.send(proto::GetChannelMessagesResponse {
Ok(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE, done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages, messages,
}) })?;
Ok(())
} }
async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> { async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
@ -1213,7 +1241,7 @@ pub async fn handle_websocket_request(
Extension(server): Extension<Arc<Server>>, Extension(server): Extension<Arc<Server>>,
Extension(user_id): Extension<UserId>, Extension(user_id): Extension<UserId>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
) -> Response { ) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION { if protocol_version != rpc::PROTOCOL_VERSION {
return ( return (
StatusCode::UPGRADE_REQUIRED, StatusCode::UPGRADE_REQUIRED,