Merge commit '680b86b17c' into main

This commit is contained in:
Max Brunsfeld 2021-09-10 15:22:59 -07:00
commit 3d4a451c15
14 changed files with 1045 additions and 439 deletions

View file

@ -3,8 +3,9 @@ use async_task::Runnable;
pub use async_task::Task;
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
use parking_lot::Mutex;
use postage::{barrier, prelude::Stream as _};
use rand::prelude::*;
use smol::{channel, prelude::*, Executor};
use smol::{channel, prelude::*, Executor, Timer};
use std::{
fmt::{self, Debug},
marker::PhantomData,
@ -18,7 +19,7 @@ use std::{
},
task::{Context, Poll},
thread,
time::Duration,
time::{Duration, Instant},
};
use waker_fn::waker_fn;
@ -49,6 +50,8 @@ struct DeterministicState {
spawned_from_foreground: Vec<(Runnable, Backtrace)>,
forbid_parking: bool,
block_on_ticks: RangeInclusive<usize>,
now: Instant,
pending_timers: Vec<(Instant, barrier::Sender)>,
}
pub struct Deterministic {
@ -67,6 +70,8 @@ impl Deterministic {
spawned_from_foreground: Default::default(),
forbid_parking: false,
block_on_ticks: 0..=1000,
now: Instant::now(),
pending_timers: Default::default(),
})),
parker: Default::default(),
}
@ -119,17 +124,39 @@ impl Deterministic {
T: 'static,
F: Future<Output = T> + 'static,
{
smol::pin!(future);
let unparker = self.parker.lock().unparker();
let woken = Arc::new(AtomicBool::new(false));
let waker = {
let woken = woken.clone();
waker_fn(move || {
let mut future = Box::pin(future);
loop {
if let Some(result) = self.run_internal(woken.clone(), &mut future) {
return result;
}
if !woken.load(SeqCst) && self.state.lock().forbid_parking {
panic!("deterministic executor parked after a call to forbid_parking");
}
woken.store(false, SeqCst);
self.parker.lock().park();
}
}
fn run_until_parked(&self) {
let woken = Arc::new(AtomicBool::new(false));
let future = std::future::pending::<()>();
smol::pin!(future);
self.run_internal(woken, future);
}
pub fn run_internal<F, T>(&self, woken: Arc<AtomicBool>, mut future: F) -> Option<T>
where
T: 'static,
F: Future<Output = T> + Unpin,
{
let unparker = self.parker.lock().unparker();
let waker = waker_fn(move || {
woken.store(true, SeqCst);
unparker.unpark();
})
};
});
let mut cx = Context::from_waker(&waker);
let mut trace = Trace::default();
@ -163,23 +190,17 @@ impl Deterministic {
runnable.run();
} else {
drop(state);
if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
return result;
if let Poll::Ready(result) = future.poll(&mut cx) {
return Some(result);
}
let state = self.state.lock();
if state.scheduled_from_foreground.is_empty()
&& state.scheduled_from_background.is_empty()
&& state.spawned_from_foreground.is_empty()
{
if state.forbid_parking && !woken.load(SeqCst) {
panic!("deterministic executor parked after a call to forbid_parking");
return None;
}
drop(state);
woken.store(false, SeqCst);
self.parker.lock().park();
}
continue;
}
}
}
@ -407,6 +428,41 @@ impl Foreground {
}
}
pub async fn timer(&self, duration: Duration) {
match self {
Self::Deterministic(executor) => {
let (tx, mut rx) = barrier::channel();
{
let mut state = executor.state.lock();
let wakeup_at = state.now + duration;
state.pending_timers.push((wakeup_at, tx));
}
rx.recv().await;
}
_ => {
Timer::after(duration).await;
}
}
}
pub fn advance_clock(&self, duration: Duration) {
match self {
Self::Deterministic(executor) => {
executor.run_until_parked();
let mut state = executor.state.lock();
state.now += duration;
let now = state.now;
let mut pending_timers = mem::take(&mut state.pending_timers);
drop(state);
pending_timers.retain(|(wakeup, _)| *wakeup > now);
executor.state.lock().pending_timers.extend(pending_timers);
}
_ => panic!("this method can only be called on a deterministic executor"),
}
}
pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
match self {
Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,

View file

@ -5,10 +5,7 @@ use super::{
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
WebSocketStream,
};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
@ -30,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
ConnectionId, Peer, TypedEnvelope,
Conn, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
@ -95,6 +92,7 @@ impl Server {
};
server
.add_handler(Server::ping)
.add_handler(Server::share_worktree)
.add_handler(Server::join_worktree)
.add_handler(Server::update_worktree)
@ -133,19 +131,12 @@ impl Server {
self
}
pub fn handle_connection<Conn>(
pub fn handle_connection(
self: &Arc<Self>,
connection: Conn,
addr: String,
user_id: UserId,
) -> impl Future<Output = ()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
+ Unpin,
{
) -> impl Future<Output = ()> {
let this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
@ -254,6 +245,11 @@ impl Server {
worktree_ids
}
async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
self.peer.respond(request.receipt(), proto::Ack {}).await?;
Ok(())
}
async fn share_worktree(
self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
@ -503,7 +499,9 @@ impl Server {
request: TypedEnvelope<proto::UpdateBuffer>,
) -> tide::Result<()> {
self.broadcast_in_worktree(request.payload.worktree_id, &request)
.await
.await?;
self.peer.respond(request.receipt(), proto::Ack {}).await?;
Ok(())
}
async fn buffer_saved(
@ -974,8 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
server.handle_connection(stream, addr, user_id).await;
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
}
});
@ -1009,17 +1006,25 @@ mod tests {
};
use async_std::{sync::RwLockReadGuard, task};
use gpui::TestAppContext;
use postage::mpsc;
use parking_lot::Mutex;
use postage::{mpsc, watch};
use serde_json::json;
use sqlx::types::time::OffsetDateTime;
use std::{path::Path, sync::Arc, time::Duration};
use std::{
path::Path,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
time::Duration,
};
use zed::{
channel::{Channel, ChannelDetails, ChannelList},
editor::{Editor, Insert},
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
rpc::Client,
settings, test,
rpc::{self, Client},
settings,
user::UserStore,
worktree::Worktree,
};
@ -1469,7 +1474,7 @@ mod tests {
.await;
// Drop client B's connection and ensure client A observes client B leaving the worktree.
client_b.disconnect().await.unwrap();
client_b.disconnect(&cx_b.to_async()).await.unwrap();
worktree_a
.condition(&cx_a, |tree, _| tree.peers().len() == 0)
.await;
@ -1675,11 +1680,206 @@ mod tests {
);
}
#[gpui::test]
async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
let mut server = TestServer::start().await;
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
let mut status_b = client_b.status();
// Create an org that includes these 2 users.
let db = &server.app_state.db;
let org_id = db.create_org("Test Org", "test-org").await.unwrap();
db.add_org_member(org_id, user_id_a, false).await.unwrap();
db.add_org_member(org_id, user_id_b, false).await.unwrap();
// Create a channel that includes all the users.
let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
db.add_channel_member(channel_id, user_id_a, false)
.await
.unwrap();
db.add_channel_member(channel_id, user_id_b, false)
.await
.unwrap();
db.create_channel_message(
channel_id,
user_id_b,
"hello A, it's B.",
OffsetDateTime::now_utc(),
)
.await
.unwrap();
let user_store_a = Arc::new(UserStore::new(client_a.clone()));
let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
channels_a
.condition(&mut cx_a, |list, _| list.available_channels().is_some())
.await;
channels_a.read_with(&cx_a, |list, _| {
assert_eq!(
list.available_channels().unwrap(),
&[ChannelDetails {
id: channel_id.to_proto(),
name: "test-channel".to_string()
}]
)
});
let channel_a = channels_a.update(&mut cx_a, |this, cx| {
this.get_channel(channel_id.to_proto(), cx).unwrap()
});
channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
channel_a
.condition(&cx_a, |channel, _| {
channel_messages(channel)
== [("user_b".to_string(), "hello A, it's B.".to_string())]
})
.await;
let user_store_b = Arc::new(UserStore::new(client_b.clone()));
let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
channels_b
.condition(&mut cx_b, |list, _| list.available_channels().is_some())
.await;
channels_b.read_with(&cx_b, |list, _| {
assert_eq!(
list.available_channels().unwrap(),
&[ChannelDetails {
id: channel_id.to_proto(),
name: "test-channel".to_string()
}]
)
});
let channel_b = channels_b.update(&mut cx_b, |this, cx| {
this.get_channel(channel_id.to_proto(), cx).unwrap()
});
channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
== [("user_b".to_string(), "hello A, it's B.".to_string())]
})
.await;
// Disconnect client B, ensuring we can still access its cached channel data.
server.forbid_connections();
server.disconnect_client(user_id_b);
while !matches!(
status_b.recv().await,
Some(rpc::Status::ReconnectionError { .. })
) {}
channels_b.read_with(&cx_b, |channels, _| {
assert_eq!(
channels.available_channels().unwrap(),
[ChannelDetails {
id: channel_id.to_proto(),
name: "test-channel".to_string()
}]
)
});
channel_b.read_with(&cx_b, |channel, _| {
assert_eq!(
channel_messages(channel),
[("user_b".to_string(), "hello A, it's B.".to_string())]
)
});
// Send a message from client A while B is disconnected.
channel_a
.update(&mut cx_a, |channel, cx| {
channel
.send_message("oh, hi B.".to_string(), cx)
.unwrap()
.detach();
let task = channel.send_message("sup".to_string(), cx).unwrap();
assert_eq!(
channel
.pending_messages()
.iter()
.map(|m| &m.body)
.collect::<Vec<_>>(),
&["oh, hi B.", "sup"]
);
task
})
.await
.unwrap();
// Give client B a chance to reconnect.
server.allow_connections();
cx_b.foreground().advance_clock(Duration::from_secs(10));
// Verify that B sees the new messages upon reconnection.
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
== [
("user_b".to_string(), "hello A, it's B.".to_string()),
("user_a".to_string(), "oh, hi B.".to_string()),
("user_a".to_string(), "sup".to_string()),
]
})
.await;
// Ensure client A and B can communicate normally after reconnection.
channel_a
.update(&mut cx_a, |channel, cx| {
channel.send_message("you online?".to_string(), cx).unwrap()
})
.await
.unwrap();
channel_b
.condition(&cx_b, |channel, _| {
channel_messages(channel)
== [
("user_b".to_string(), "hello A, it's B.".to_string()),
("user_a".to_string(), "oh, hi B.".to_string()),
("user_a".to_string(), "sup".to_string()),
("user_a".to_string(), "you online?".to_string()),
]
})
.await;
channel_b
.update(&mut cx_b, |channel, cx| {
channel.send_message("yep".to_string(), cx).unwrap()
})
.await
.unwrap();
channel_a
.condition(&cx_a, |channel, _| {
channel_messages(channel)
== [
("user_b".to_string(), "hello A, it's B.".to_string()),
("user_a".to_string(), "oh, hi B.".to_string()),
("user_a".to_string(), "sup".to_string()),
("user_a".to_string(), "you online?".to_string()),
("user_b".to_string(), "yep".to_string()),
]
})
.await;
fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
channel
.messages()
.cursor::<(), ()>()
.map(|m| (m.sender.github_login.clone(), m.body.clone()))
.collect()
}
}
struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
notifications: mpsc::Receiver<()>,
connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
@ -1695,6 +1895,8 @@ mod tests {
app_state,
server,
notifications: notifications.1,
connection_killers: Default::default(),
forbid_connections: Default::default(),
_test_db: test_db,
}
}
@ -1704,20 +1906,67 @@ mod tests {
cx: &mut TestAppContext,
name: &str,
) -> (UserId, Arc<Client>) {
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client = Client::new();
let (client_conn, server_conn) = test::Channel::bidirectional();
let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
let client_name = name.to_string();
let mut client = Client::new();
let server = self.server.clone();
let connection_killers = self.connection_killers.clone();
let forbid_connections = self.forbid_connections.clone();
Arc::get_mut(&mut client)
.unwrap()
.set_login_and_connect_callbacks(
move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok((client_user_id.0 as u64, access_token))
})
},
move |user_id, access_token, cx| {
assert_eq!(user_id, client_user_id.0 as u64);
assert_eq!(access_token, "the-token");
let server = server.clone();
let connection_killers = connection_killers.clone();
let forbid_connections = forbid_connections.clone();
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
if forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
} else {
let (client_conn, server_conn, kill_conn) = Conn::in_memory();
connection_killers.lock().insert(client_user_id, kill_conn);
cx.background()
.spawn(
self.server
.handle_connection(server_conn, name.to_string(), user_id),
)
.spawn(server.handle_connection(
server_conn,
client_name,
client_user_id,
))
.detach();
Ok(client_conn)
}
})
},
);
client
.add_connection(user_id.to_proto(), client_conn, &cx.to_async())
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
(user_id, client)
(client_user_id, client)
}
fn disconnect_client(&self, user_id: UserId) {
if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
let _ = kill_conn.try_send(Some(()));
}
}
fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

View file

@ -11,6 +11,7 @@ use gpui::{
use postage::prelude::Stream;
use std::{
collections::{HashMap, HashSet},
mem,
ops::Range,
sync::Arc,
};
@ -71,7 +72,7 @@ pub enum ChannelListEvent {}
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
MessagesAdded {
MessagesUpdated {
old_range: Range<usize>,
new_count: usize,
},
@ -87,36 +88,47 @@ impl ChannelList {
rpc: Arc<rpc::Client>,
cx: &mut ModelContext<Self>,
) -> Self {
let _task = cx.spawn(|this, mut cx| {
let _task = cx.spawn_weak(|this, mut cx| {
let rpc = rpc.clone();
async move {
let mut user_id = rpc.user_id();
loop {
let available_channels = if user_id.recv().await.unwrap().is_some() {
Some(
rpc.request(proto::GetChannels {})
let mut status = rpc.status();
while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
match status {
rpc::Status::Connected { .. } => {
let response = rpc
.request(proto::GetChannels {})
.await
.context("failed to fetch available channels")?
.channels
.into_iter()
.map(Into::into)
.collect(),
)
} else {
None
};
.context("failed to fetch available channels")?;
this.update(&mut cx, |this, cx| {
if available_channels.is_none() {
if this.available_channels.is_none() {
return;
this.available_channels =
Some(response.channels.into_iter().map(Into::into).collect());
let mut to_remove = Vec::new();
for (channel_id, channel) in &this.channels {
if let Some(channel) = channel.upgrade(cx) {
channel.update(cx, |channel, cx| channel.rejoin(cx))
} else {
to_remove.push(*channel_id);
}
this.channels.clear();
}
this.available_channels = available_channels;
for channel_id in to_remove {
this.channels.remove(&channel_id);
}
cx.notify();
});
}
rpc::Status::Disconnected { .. } => {
this.update(&mut cx, |this, cx| {
this.available_channels = None;
this.channels.clear();
cx.notify();
});
}
_ => {}
}
}
Ok(())
}
.log_err()
});
@ -285,6 +297,43 @@ impl Channel {
false
}
pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let channel_id = self.details.id;
cx.spawn(|channel, mut cx| {
async move {
let response = rpc.request(proto::JoinChannel { channel_id }).await?;
let messages = messages_from_proto(response.messages, &user_store).await?;
let loaded_all_messages = response.done;
channel.update(&mut cx, |channel, cx| {
if let Some((first_new_message, last_old_message)) =
messages.first().zip(channel.messages.last())
{
if first_new_message.id > last_old_message.id {
let old_messages = mem::take(&mut channel.messages);
cx.emit(ChannelEvent::MessagesUpdated {
old_range: 0..old_messages.summary().count,
new_count: 0,
});
channel.loaded_all_messages = loaded_all_messages;
}
}
channel.insert_messages(messages, cx);
if loaded_all_messages {
channel.loaded_all_messages = loaded_all_messages;
}
});
Ok(())
}
.log_err()
})
.detach();
}
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
@ -350,7 +399,7 @@ impl Channel {
drop(old_cursor);
self.messages = new_messages;
cx.emit(ChannelEvent::MessagesAdded {
cx.emit(ChannelEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count,
});
@ -446,22 +495,21 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
#[cfg(test)]
mod tests {
use super::*;
use crate::test::FakeServer;
use gpui::TestAppContext;
use postage::mpsc::Receiver;
use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
let user_id = 5;
let client = Client::new();
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
let user_store = Arc::new(UserStore::new(client.clone()));
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
// Get the available channels.
let get_channels = server.receive::<proto::GetChannels>().await;
let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
server
.respond(
get_channels.receipt(),
@ -492,7 +540,7 @@ mod tests {
})
.unwrap();
channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
let join_channel = server.receive::<proto::JoinChannel>().await;
let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
server
.respond(
join_channel.receipt(),
@ -517,7 +565,7 @@ mod tests {
.await;
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await;
let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
get_users.payload.user_ids.sort();
assert_eq!(get_users.payload.user_ids, vec![5, 6]);
server
@ -542,7 +590,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::MessagesAdded {
ChannelEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
@ -574,7 +622,7 @@ mod tests {
.await;
// Client requests user for message since they haven't seen them yet
let get_users = server.receive::<proto::GetUsers>().await;
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![7]);
server
.respond(
@ -591,7 +639,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::MessagesAdded {
ChannelEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
}
@ -610,7 +658,7 @@ mod tests {
channel.update(&mut cx, |channel, cx| {
assert!(channel.load_more_messages(cx));
});
let get_messages = server.receive::<proto::GetChannelMessages>().await;
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server
@ -638,7 +686,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
ChannelEvent::MessagesAdded {
ChannelEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
@ -656,53 +704,4 @@ mod tests {
);
});
}
struct FakeServer {
peer: Arc<Peer>,
incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
connection_id: ConnectionId,
}
impl FakeServer {
async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
let (client_conn, server_conn) = Channel::bidirectional();
let peer = Peer::new();
let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
client
.add_connection(user_id, client_conn, &cx.to_async())
.await
.unwrap();
Self {
peer,
incoming,
connection_id,
}
}
async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id, message).await.unwrap();
}
async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
*self
.incoming
.recv()
.await
.unwrap()
.into_any()
.downcast::<TypedEnvelope<M>>()
.unwrap()
}
async fn respond<T: proto::RequestMessage>(
&self,
receipt: Receipt<T>,
response: T::Response,
) {
self.peer.respond(receipt, response).await.unwrap()
}
}
}

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
editor::Editor,
rpc::Client,
rpc::{self, Client},
theme,
util::{ResultExt, TryFutureExt},
Settings,
@ -14,10 +14,10 @@ use gpui::{
keymap::Binding,
platform::CursorStyle,
views::{ItemType, Select, SelectStyle},
AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
ViewContext, ViewHandle,
};
use postage::watch;
use postage::{prelude::Stream, watch};
use time::{OffsetDateTime, UtcOffset};
const MESSAGE_LOADING_THRESHOLD: usize = 50;
@ -31,6 +31,7 @@ pub struct ChatPanel {
channel_select: ViewHandle<Select>,
settings: watch::Receiver<Settings>,
local_timezone: UtcOffset,
_observe_status: Task<()>,
}
pub enum Event {}
@ -98,6 +99,14 @@ impl ChatPanel {
cx.dispatch_action(LoadMoreMessages);
}
});
let _observe_status = cx.spawn(|this, mut cx| {
let mut status = rpc.status();
async move {
while let Some(_) = status.recv().await {
this.update(&mut cx, |_, cx| cx.notify());
}
}
});
let mut this = Self {
rpc,
@ -108,6 +117,7 @@ impl ChatPanel {
channel_select,
settings,
local_timezone: cx.platform().local_timezone(),
_observe_status,
};
this.init_active_channel(cx);
@ -153,6 +163,7 @@ impl ChatPanel {
if let Some(active_channel) = active_channel {
self.set_active_channel(active_channel, cx);
} else {
self.message_list.reset(0);
self.active_channel = None;
}
@ -183,7 +194,7 @@ impl ChatPanel {
cx: &mut ViewContext<Self>,
) {
match event {
ChannelEvent::MessagesAdded {
ChannelEvent::MessagesUpdated {
old_range,
new_count,
} => {
@ -357,10 +368,6 @@ impl ChatPanel {
})
}
}
fn is_signed_in(&self) -> bool {
self.rpc.user_id().borrow().is_some()
}
}
impl Entity for ChatPanel {
@ -374,10 +381,9 @@ impl View for ChatPanel {
fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
let theme = &self.settings.borrow().theme;
let element = if self.is_signed_in() {
self.render_channel()
} else {
self.render_sign_in_prompt(cx)
let element = match *self.rpc.status().borrow() {
rpc::Status::Connected { .. } => self.render_channel(),
_ => self.render_sign_in_prompt(cx),
};
ConstrainedBox::new(
Container::new(element)
@ -389,7 +395,7 @@ impl View for ChatPanel {
}
fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
if self.is_signed_in() {
if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
cx.focus(&self.input_editor);
}
}

View file

@ -2695,14 +2695,7 @@ impl<'a> Into<proto::operation::Edit> for &'a EditOperation {
impl<'a> Into<proto::Anchor> for &'a Anchor {
fn into(self) -> proto::Anchor {
proto::Anchor {
version: self
.version
.iter()
.map(|entry| proto::VectorClockEntry {
replica_id: entry.replica_id as u32,
timestamp: entry.value,
})
.collect(),
version: (&self.version).into(),
offset: self.offset as u64,
bias: match self.bias {
Bias::Left => proto::anchor::Bias::Left as i32,

View file

@ -1,24 +1,24 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use postage::prelude::Stream;
use postage::sink::Sink;
use postage::watch;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Weak;
use std::time::{Duration, Instant};
use std::{convert::TryFrom, future::Future, sync::Arc};
use postage::{prelude::Stream, watch};
use rand::prelude::*;
use std::{
any::TypeId,
collections::HashMap,
convert::TryFrom,
future::Future,
sync::{Arc, Weak},
time::{Duration, Instant},
};
use surf::Url;
use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
proto::{EnvelopedMessage, RequestMessage},
Peer, Receipt,
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Conn, Peer, Receipt,
};
lazy_static! {
@ -29,25 +29,55 @@ lazy_static! {
pub struct Client {
peer: Arc<Peer>,
state: RwLock<ClientState>,
auth_callback: Option<
Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
>,
connect_callback: Option<
Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
>,
}
#[derive(Copy, Clone, Debug)]
pub enum Status {
Disconnected,
Authenticating,
Connecting {
user_id: u64,
},
ConnectionError,
Connected {
connection_id: ConnectionId,
user_id: u64,
},
ConnectionLost,
Reauthenticating,
Reconnecting {
user_id: u64,
},
ReconnectionError {
next_reconnection: Instant,
},
}
struct ClientState {
connection_id: Option<ConnectionId>,
user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
(TypeId, u64),
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
>,
_maintain_connection: Option<Task<()>>,
heartbeat_interval: Duration,
}
impl Default for ClientState {
fn default() -> Self {
Self {
connection_id: Default::default(),
user_id: watch::channel(),
status: watch::channel_with(Status::Disconnected),
entity_id_extractors: Default::default(),
model_handlers: Default::default(),
_maintain_connection: None,
heartbeat_interval: Duration::from_secs(5),
}
}
}
@ -77,11 +107,71 @@ impl Client {
Arc::new(Self {
peer: Peer::new(),
state: Default::default(),
auth_callback: None,
connect_callback: None,
})
}
pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
self.state.read().user_id.1.clone()
#[cfg(any(test, feature = "test-support"))]
pub fn set_login_and_connect_callbacks<Login, Connect>(
&mut self,
login: Login,
connect: Connect,
) where
Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
{
self.auth_callback = Some(Box::new(login));
self.connect_callback = Some(Box::new(connect));
}
pub fn status(&self) -> watch::Receiver<Status> {
self.state.read().status.1.clone()
}
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
let mut state = self.state.write();
*state.status.0.borrow_mut() = status;
match status {
Status::Connected { .. } => {
let heartbeat_interval = state.heartbeat_interval;
let this = self.clone();
let foreground = cx.foreground();
state._maintain_connection = Some(cx.foreground().spawn(async move {
loop {
foreground.timer(heartbeat_interval).await;
this.request(proto::Ping {}).await.unwrap();
}
}));
}
Status::ConnectionLost => {
let this = self.clone();
let foreground = cx.foreground();
let heartbeat_interval = state.heartbeat_interval;
state._maintain_connection = Some(cx.spawn(|cx| async move {
let mut rng = StdRng::from_entropy();
let mut delay = Duration::from_millis(100);
while let Err(error) = this.authenticate_and_connect(&cx).await {
log::error!("failed to connect {}", error);
this.set_status(
Status::ReconnectionError {
next_reconnection: Instant::now() + delay,
},
&cx,
);
foreground.timer(delay).await;
delay = delay
.mul_f32(rng.gen_range(1.0..=2.0))
.min(heartbeat_interval);
}
}));
}
Status::Disconnected => {
state._maintain_connection.take();
}
_ => {}
}
}
pub fn subscribe_from_model<T, M, F>(
@ -141,56 +231,57 @@ impl Client {
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
if self.state.read().connection_id.is_some() {
return Ok(());
let was_disconnected = match *self.status().borrow() {
Status::Disconnected => true,
Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
false
}
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
let user_id = user_id.parse::<u64>()?;
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
self.add_connection(user_id, stream, cx).await?;
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
self.add_connection(user_id, stream, cx).await?;
} else {
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
Status::Connected { .. }
| Status::Connecting { .. }
| Status::Reconnecting { .. }
| Status::Authenticating
| Status::Reauthenticating => return Ok(()),
};
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
Ok(())
if was_disconnected {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx)
}
pub async fn add_connection<Conn>(
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
) -> anyhow::Result<()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Unpin
+ Send,
{
let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result,
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return Err(err);
}
};
if was_disconnected {
self.set_status(Status::Connecting { user_id }, cx);
} else {
self.set_status(Status::Reconnecting { user_id }, cx);
}
match self.connect(user_id, &access_token, cx).await {
Ok(conn) => {
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
self.set_connection(user_id, conn, cx).await;
Ok(())
}
Err(err) => {
self.set_status(Status::ConnectionError, cx);
Err(err)
}
}
}
async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
{
cx.foreground()
.spawn({
let mut cx = cx.clone();
let this = self.clone();
cx.foreground()
.spawn(async move {
async move {
while let Some(message) = incoming.recv().await {
let mut state = this.state.write();
if let Some(extract_entity_id) =
@ -215,27 +306,90 @@ impl Client {
log::info!("unhandled message {}", message.payload_type_name());
}
}
}
})
.detach();
}
cx.background()
self.set_status(
Status::Connected {
connection_id,
user_id,
},
cx,
);
let handle_io = cx.background().spawn(handle_io);
let this = self.clone();
let cx = cx.clone();
cx.foreground()
.spawn(async move {
if let Err(error) = handle_io.await {
log::error!("connection error: {:?}", error);
match handle_io.await {
Ok(()) => this.set_status(Status::Disconnected, &cx),
Err(err) => {
log::error!("connection error: {:?}", err);
this.set_status(Status::ConnectionLost, &cx);
}
}
})
.detach();
let mut state = self.state.write();
state.connection_id = Some(connection_id);
state.user_id.0.send(Some(user_id)).await?;
Ok(())
}
pub fn login(
platform: Arc<dyn gpui::Platform>,
executor: &Arc<gpui::executor::Background>,
) -> Task<Result<(String, String)>> {
let executor = executor.clone();
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
if let Some(callback) = self.auth_callback.as_ref() {
callback(cx)
} else {
self.authenticate_with_browser(cx)
}
}
fn connect(
self: &Arc<Self>,
user_id: u64,
access_token: &str,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
if let Some(callback) = self.connect_callback.as_ref() {
callback(user_id, access_token, cx)
} else {
self.connect_with_websocket(user_id, access_token, cx)
}
}
fn connect_with_websocket(
self: &Arc<Self>,
user_id: u64,
access_token: &str,
cx: &AsyncAppContext,
) -> Task<Result<Conn>> {
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
cx.background().spawn(async move {
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
Ok(Conn::new(stream))
} else {
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
}
})
}
pub fn authenticate_with_browser(
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> Task<Result<(u64, String)>> {
let platform = cx.platform();
let executor = cx.background();
executor.clone().spawn(async move {
if let Some((user_id, access_token)) = platform
.read_credentials(&ZED_SERVER_URL)
@ -243,7 +397,7 @@ impl Client {
.flatten()
{
log::info!("already signed in. user_id: {}", user_id);
return Ok((user_id, String::from_utf8(access_token).unwrap()));
return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
}
// Generate a pair of asymmetric encryption keys. The public key will be used by the
@ -309,21 +463,23 @@ impl Client {
platform
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
.log_err();
Ok((user_id.to_string(), access_token))
Ok((user_id.parse()?, access_token))
})
}
pub async fn disconnect(&self) -> Result<()> {
pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await;
self.set_status(Status::Disconnected, cx);
Ok(())
}
fn connection_id(&self) -> Result<ConnectionId> {
self.state
.read()
.connection_id
.ok_or_else(|| anyhow!("not connected"))
if let Status::Connected { connection_id, .. } = *self.status().borrow() {
Ok(connection_id)
} else {
Err(anyhow!("not connected"))
}
}
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
@ -343,35 +499,6 @@ impl Client {
}
}
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
type Output: 'a + Future<Output = anyhow::Result<()>>;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output;
}
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
M: proto::EnvelopedMessage,
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
Fut: 'a + Future<Output = anyhow::Result<()>>,
{
type Output = Fut;
fn handle(
&self,
message: TypedEnvelope<M>,
rpc: &'a Client,
cx: &'a mut gpui::AsyncAppContext,
) -> Self::Output {
(self)(message, rpc, cx)
}
}
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@ -396,6 +523,54 @@ const LOGIN_RESPONSE: &'static str = "
</html>
";
#[cfg(test)]
mod tests {
use super::*;
use crate::test::FakeServer;
use gpui::TestAppContext;
#[gpui::test(iterations = 10)]
async fn test_heartbeat(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
cx.foreground().advance_clock(Duration::from_secs(10));
let ping = server.receive::<proto::Ping>().await.unwrap();
server.respond(ping.receipt(), proto::Ack {}).await;
cx.foreground().advance_clock(Duration::from_secs(10));
let ping = server.receive::<proto::Ping>().await.unwrap();
server.respond(ping.receipt(), proto::Ack {}).await;
client.disconnect(&cx.to_async()).await.unwrap();
assert!(server.receive::<proto::Ping>().await.is_err());
}
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: TestAppContext) {
cx.foreground().forbid_parking();
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
let mut status = client.status();
assert!(matches!(
status.recv().await,
Some(Status::Connected { .. })
));
server.forbid_connections();
server.disconnect().await;
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
}
#[test]
fn test_encode_and_decode_worktree_url() {
let url = encode_worktree_url(5, "deadbeef");
@ -406,3 +581,4 @@ fn test_encode_and_decode_worktree_url() {
);
assert_eq!(decode_worktree_url("not://the-right-format"), None);
}
}

View file

@ -3,24 +3,27 @@ use crate::{
channel::ChannelList,
fs::RealFs,
language::LanguageRegistry,
rpc,
rpc::{self, Client},
settings::{self, ThemeRegistry},
time::ReplicaId,
user::UserStore,
AppState,
};
use gpui::{Entity, ModelHandle, MutableAppContext};
use anyhow::{anyhow, Result};
use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
use postage::{mpsc, prelude::Stream as _};
use smol::channel;
use std::{
marker::PhantomData,
path::{Path, PathBuf},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use tempdir::TempDir;
#[cfg(feature = "test-support")]
pub use zrpc::test::Channel;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)]
#[ctor::ctor]
@ -195,3 +198,117 @@ impl<T: Entity> Observer<T> {
(observer, notify_rx)
}
}
pub struct FakeServer {
peer: Arc<Peer>,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
}
impl FakeServer {
pub async fn for_client(
client_user_id: u64,
client: &mut Arc<Client>,
cx: &TestAppContext,
) -> Arc<Self> {
let result = Arc::new(Self {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
forbid_connections: Default::default(),
});
Arc::get_mut(client)
.unwrap()
.set_login_and_connect_callbacks(
move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok((client_user_id, access_token))
})
},
{
let server = result.clone();
move |user_id, access_token, cx| {
assert_eq!(user_id, client_user_id);
assert_eq!(access_token, "the-token");
cx.spawn({
let server = server.clone();
move |cx| async move { server.connect(&cx).await }
})
}
},
);
client
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
result
}
pub async fn disconnect(&self) {
self.peer.disconnect(self.connection_id()).await;
self.connection_id.lock().take();
self.incoming.lock().take();
}
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
if self.forbid_connections.load(SeqCst) {
Err(anyhow!("server is forbidding connections"))
} else {
let (client_conn, server_conn, _) = Conn::in_memory();
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
cx.background().spawn(io).detach();
*self.incoming.lock() = Some(incoming);
*self.connection_id.lock() = Some(connection_id);
Ok(client_conn)
}
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
pub fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
self.peer.send(self.connection_id(), message).await.unwrap();
}
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
let message = self
.incoming
.lock()
.as_mut()
.expect("not connected")
.recv()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
let type_name = message.payload_type_name();
Ok(*message
.into_any()
.downcast::<TypedEnvelope<M>>()
.unwrap_or_else(|_| {
panic!(
"fake server received unexpected message type: {:?}",
type_name
);
}))
}
pub async fn respond<T: proto::RequestMessage>(
&self,
receipt: Receipt<T>,
response: T::Response,
) {
self.peer.respond(receipt, response).await.unwrap()
}
fn connection_id(&self) -> ConnectionId {
self.connection_id.lock().expect("not connected")
}
}

View file

@ -234,6 +234,7 @@ impl Worktree {
.into_iter()
.map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
.collect(),
queued_operations: Default::default(),
languages,
_subscriptions,
})
@ -656,6 +657,7 @@ pub struct LocalWorktree {
shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
peers: HashMap<PeerId, ReplicaId>,
languages: Arc<LanguageRegistry>,
queued_operations: Vec<(u64, Operation)>,
fs: Arc<dyn Fs>,
}
@ -711,6 +713,7 @@ impl LocalWorktree {
poll_task: None,
open_buffers: Default::default(),
shared_buffers: Default::default(),
queued_operations: Default::default(),
peers: Default::default(),
languages,
fs,
@ -1091,6 +1094,7 @@ pub struct RemoteWorktree {
open_buffers: HashMap<usize, RemoteBuffer>,
peers: HashMap<PeerId, ReplicaId>,
languages: Arc<LanguageRegistry>,
queued_operations: Vec<(u64, Operation)>,
_subscriptions: Vec<rpc::Subscription>,
}
@ -1550,16 +1554,23 @@ impl File {
.map(|share| (share.rpc.clone(), share.remote_id)),
Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
} {
cx.spawn(|_, _| async move {
cx.spawn(|worktree, mut cx| async move {
if let Err(error) = rpc
.send(proto::UpdateBuffer {
.request(proto::UpdateBuffer {
worktree_id: remote_id,
buffer_id,
operations: Some(operation).iter().map(Into::into).collect(),
operations: vec![(&operation).into()],
})
.await
{
worktree.update(&mut cx, |worktree, _| {
log::error!("error sending buffer operation: {}", error);
match worktree {
Worktree::Local(t) => &mut t.queued_operations,
Worktree::Remote(t) => &mut t.queued_operations,
}
.push((buffer_id, operation));
});
}
})
.detach();
@ -1582,7 +1593,7 @@ impl File {
.await
{
log::error!("error closing remote buffer: {}", error);
};
}
})
.detach();
}

View file

@ -6,9 +6,9 @@ message Envelope {
optional uint32 responding_to = 2;
optional uint32 original_sender_id = 3;
oneof payload {
Error error = 4;
Ping ping = 5;
Pong pong = 6;
Ack ack = 4;
Error error = 5;
Ping ping = 6;
ShareWorktree share_worktree = 7;
ShareWorktreeResponse share_worktree_response = 8;
OpenWorktree open_worktree = 9;
@ -40,13 +40,9 @@ message Envelope {
// Messages
message Ping {
int32 id = 1;
}
message Ping {}
message Pong {
int32 id = 2;
}
message Ack {}
message Error {
string message = 1;

101
zrpc/src/conn.rs Normal file
View file

@ -0,0 +1,101 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
use std::{io, task::Poll};
pub struct Conn {
pub(crate) tx:
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
pub(crate) rx: Box<
dyn 'static
+ Send
+ Unpin
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
>,
}
impl Conn {
pub fn new<S>(stream: S) -> Self
where
S: 'static
+ Send
+ Unpin
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
{
let (tx, rx) = stream.split();
Self {
tx: Box::new(tx),
rx: Box::new(rx),
}
}
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
self.tx.send(message).await
}
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
let (a_tx, a_rx) = Self::channel(kill_rx.clone());
let (b_tx, b_rx) = Self::channel(kill_rx);
(
Self { tx: a_tx, rx: b_rx },
Self { tx: b_tx, rx: a_rx },
kill_tx,
)
}
#[cfg(any(test, feature = "test-support"))]
fn channel(
kill_rx: postage::watch::Receiver<Option<()>>,
) -> (
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
) {
use futures::{future, SinkExt as _};
use io::{Error, ErrorKind};
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
let tx = tx
.sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
.with({
let kill_rx = kill_rx.clone();
move |msg| {
if kill_rx.borrow().is_none() {
future::ready(Ok(msg))
} else {
future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
}
}
});
let rx = KillableReceiver { kill_rx, rx };
(Box::new(tx), Box::new(rx))
}
}
struct KillableReceiver {
rx: mpsc::UnboundedReceiver<WebSocketMessage>,
kill_rx: postage::watch::Receiver<Option<()>>,
}
impl Stream for KillableReceiver {
type Item = Result<WebSocketMessage, WebSocketError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::Other,
"connection killed",
)
.into())))
} else {
self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
}
}
}

View file

@ -1,7 +1,6 @@
pub mod auth;
mod conn;
mod peer;
pub mod proto;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
pub use conn::Conn;
pub use peer::*;

View file

@ -1,8 +1,8 @@
use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Conn;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{FutureExt, StreamExt};
use futures::FutureExt as _;
use postage::{
mpsc,
prelude::{Sink as _, Stream as _},
@ -98,21 +98,14 @@ impl Peer {
})
}
pub async fn add_connection<Conn>(
pub async fn add_connection(
self: &Arc<Self>,
conn: Conn,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
)
where
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
+ Unpin,
{
let (tx, rx) = conn.split();
) {
let connection_id = ConnectionId(
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
@ -124,9 +117,10 @@ impl Peer {
next_message_id: Default::default(),
response_channels: Default::default(),
};
let mut writer = MessageStream::new(tx);
let mut reader = MessageStream::new(rx);
let mut writer = MessageStream::new(conn.tx);
let mut reader = MessageStream::new(conn.rx);
let this = self.clone();
let response_channels = connection.response_channels.clone();
let handle_io = async move {
loop {
@ -147,6 +141,7 @@ impl Peer {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
if incoming_tx.send(envelope).await.is_err() {
response_channels.lock().await.clear();
this.connections.write().await.remove(&connection_id);
return Ok(())
}
} else {
@ -158,6 +153,7 @@ impl Peer {
}
Err(error) => {
response_channels.lock().await.clear();
this.connections.write().await.remove(&connection_id);
Err(error).context("received invalid RPC message")?;
}
},
@ -165,11 +161,13 @@ impl Peer {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
response_channels.lock().await.clear();
this.connections.write().await.remove(&connection_id);
Err(result).context("failed to write RPC message")?;
}
}
None => {
response_channels.lock().await.clear();
this.connections.write().await.remove(&connection_id);
return Ok(())
}
}
@ -342,7 +340,9 @@ impl Peer {
#[cfg(test)]
mod tests {
use super::*;
use crate::{test, TypedEnvelope};
use crate::TypedEnvelope;
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use futures::StreamExt as _;
#[test]
fn test_request_response() {
@ -352,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@ -371,18 +371,18 @@ mod tests {
assert_eq!(
client1
.request(client1_conn_id, proto::Ping { id: 1 },)
.request(client1_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Pong { id: 1 }
proto::Ack {}
);
assert_eq!(
client2
.request(client2_conn_id, proto::Ping { id: 2 },)
.request(client2_conn_id, proto::Ping {},)
.await
.unwrap(),
proto::Pong { id: 2 }
proto::Ack {}
);
assert_eq!(
@ -438,13 +438,7 @@ mod tests {
let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
peer.respond(
receipt,
proto::Pong {
id: envelope.payload.id,
},
)
.await?
peer.respond(receipt, proto::Ack {}).await?
} else if let Some(envelope) =
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
{
@ -492,7 +486,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
let (client_conn, mut server_conn) = test::Channel::bidirectional();
let (client_conn, mut server_conn, _) = Conn::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@ -516,18 +510,17 @@ mod tests {
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
assert!(
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
assert!(server_conn
.send(WebSocketMessage::Binary(vec![]))
.await
.is_err()
);
.is_err());
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
let (client_conn, server_conn) = test::Channel::bidirectional();
let (client_conn, server_conn, _) = Conn::in_memory();
drop(server_conn);
let client = Peer::new();
@ -537,7 +530,7 @@ mod tests {
smol::spawn(async move { incoming.next().await }).detach();
let err = client
.request(connection_id, proto::Ping { id: 42 })
.request(connection_id, proto::Ping {})
.await
.unwrap_err();
assert_eq!(err.to_string(), "connection was closed");

View file

@ -120,6 +120,7 @@ macro_rules! entity_messages {
}
messages!(
Ack,
AddPeer,
BufferSaved,
ChannelMessageSent,
@ -140,7 +141,6 @@ messages!(
OpenWorktree,
OpenWorktreeResponse,
Ping,
Pong,
RemovePeer,
SaveBuffer,
SendChannelMessage,
@ -157,8 +157,9 @@ request_messages!(
(JoinChannel, JoinChannelResponse),
(OpenBuffer, OpenBufferResponse),
(OpenWorktree, OpenWorktreeResponse),
(Ping, Pong),
(Ping, Ack),
(SaveBuffer, BufferSaved),
(UpdateBuffer, Ack),
(ShareWorktree, ShareWorktreeResponse),
(SendChannelMessage, SendChannelMessageResponse),
(GetChannelMessages, GetChannelMessagesResponse),
@ -247,30 +248,3 @@ impl From<SystemTime> for Timestamp {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test;
#[test]
fn test_round_trip_message() {
smol::block_on(async {
let stream = test::Channel::new();
let message1 = Ping { id: 5 }.into_envelope(3, None, None);
let message2 = OpenBuffer {
worktree_id: 0,
path: "some/path".to_string(),
}
.into_envelope(5, None, None);
let mut message_stream = MessageStream::new(stream);
message_stream.write_message(&message1).await.unwrap();
message_stream.write_message(&message2).await.unwrap();
let decoded_message1 = message_stream.read_message().await.unwrap();
let decoded_message2 = message_stream.read_message().await.unwrap();
assert_eq!(decoded_message1, message1);
assert_eq!(decoded_message2, message2);
});
}
}

View file

@ -1,64 +0,0 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
pub struct Channel {
tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
}
impl Channel {
pub fn new() -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self { tx, rx }
}
pub fn bidirectional() -> (Self, Self) {
let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
let a = Self { tx: a_tx, rx: b_rx };
let b = Self { tx: b_tx, rx: a_rx };
(a, b)
}
}
impl futures::Sink<WebSocketMessage> for Channel {
type Error = WebSocketError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_ready(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
Pin::new(&mut self.tx)
.start_send(item)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_flush(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.tx)
.poll_close(cx)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
}
}
impl futures::Stream for Channel {
type Item = Result<WebSocketMessage, WebSocketError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.rx)
.poll_next(cx)
.map(|i| i.map(|i| Ok(i)))
}
}