Merge commit '680b86b17c' into main

commit 3d4a451c15

14 changed files with 1045 additions and 439 deletions
@@ -3,8 +3,9 @@ use async_task::Runnable;
pub use async_task::Task;
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
use parking_lot::Mutex;
use postage::{barrier, prelude::Stream as _};
use rand::prelude::*;
use smol::{channel, prelude::*, Executor};
use smol::{channel, prelude::*, Executor, Timer};
use std::{
    fmt::{self, Debug},
    marker::PhantomData,

@@ -18,7 +19,7 @@ use std::{
    },
    task::{Context, Poll},
    thread,
    time::Duration,
    time::{Duration, Instant},
};
use waker_fn::waker_fn;

@@ -49,6 +50,8 @@ struct DeterministicState {
    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
    forbid_parking: bool,
    block_on_ticks: RangeInclusive<usize>,
    now: Instant,
    pending_timers: Vec<(Instant, barrier::Sender)>,
}

pub struct Deterministic {

@@ -67,6 +70,8 @@ impl Deterministic {
            spawned_from_foreground: Default::default(),
            forbid_parking: false,
            block_on_ticks: 0..=1000,
            now: Instant::now(),
            pending_timers: Default::default(),
        })),
        parker: Default::default(),
    }

@@ -119,17 +124,39 @@ impl Deterministic {
        T: 'static,
        F: Future<Output = T> + 'static,
    {
        smol::pin!(future);

        let unparker = self.parker.lock().unparker();
        let woken = Arc::new(AtomicBool::new(false));
        let waker = {
            let woken = woken.clone();
            waker_fn(move || {
        let mut future = Box::pin(future);
        loop {
            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
                return result;
            }

            if !woken.load(SeqCst) && self.state.lock().forbid_parking {
                panic!("deterministic executor parked after a call to forbid_parking");
            }

            woken.store(false, SeqCst);
            self.parker.lock().park();
        }
    }

    fn run_until_parked(&self) {
        let woken = Arc::new(AtomicBool::new(false));
        let future = std::future::pending::<()>();
        smol::pin!(future);
        self.run_internal(woken, future);
    }

    pub fn run_internal<F, T>(&self, woken: Arc<AtomicBool>, mut future: F) -> Option<T>
    where
        T: 'static,
        F: Future<Output = T> + Unpin,
    {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        })
            };
        });

        let mut cx = Context::from_waker(&waker);
        let mut trace = Trace::default();

@@ -163,23 +190,17 @@ impl Deterministic {
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
                    return result;
                if let Poll::Ready(result) = future.poll(&mut cx) {
                    return Some(result);
                }

                let state = self.state.lock();
                if state.scheduled_from_foreground.is_empty()
                    && state.scheduled_from_background.is_empty()
                    && state.spawned_from_foreground.is_empty()
                {
                    if state.forbid_parking && !woken.load(SeqCst) {
                        panic!("deterministic executor parked after a call to forbid_parking");
                    return None;
                    }
                    drop(state);
                    woken.store(false, SeqCst);
                    self.parker.lock().park();
                }

                continue;
            }
        }
    }

@@ -407,6 +428,41 @@ impl Foreground {
        }
    }

    pub async fn timer(&self, duration: Duration) {
        match self {
            Self::Deterministic(executor) => {
                let (tx, mut rx) = barrier::channel();
                {
                    let mut state = executor.state.lock();
                    let wakeup_at = state.now + duration;
                    state.pending_timers.push((wakeup_at, tx));
                }
                rx.recv().await;
            }
            _ => {
                Timer::after(duration).await;
            }
        }
    }

    pub fn advance_clock(&self, duration: Duration) {
        match self {
            Self::Deterministic(executor) => {
                executor.run_until_parked();

                let mut state = executor.state.lock();
                state.now += duration;
                let now = state.now;
                let mut pending_timers = mem::take(&mut state.pending_timers);
                drop(state);

                pending_timers.retain(|(wakeup, _)| *wakeup > now);
                executor.state.lock().pending_timers.extend(pending_timers);
            }
            _ => panic!("this method can only be called on a deterministic executor"),
        }
    }

    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
        match self {
            Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,

@@ -5,10 +5,7 @@ use super::{
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{
    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
    WebSocketStream,
};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};

@@ -30,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
    auth::random_token,
    proto::{self, AnyTypedEnvelope, EnvelopedMessage},
    ConnectionId, Peer, TypedEnvelope,
    Conn, ConnectionId, Peer, TypedEnvelope,
};

type ReplicaId = u16;

@@ -95,6 +92,7 @@ impl Server {
        };

        server
            .add_handler(Server::ping)
            .add_handler(Server::share_worktree)
            .add_handler(Server::join_worktree)
            .add_handler(Server::update_worktree)

@@ -133,19 +131,12 @@ impl Server {
        self
    }

    pub fn handle_connection<Conn>(
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Conn,
        addr: String,
        user_id: UserId,
    ) -> impl Future<Output = ()>
    where
        Conn: 'static
            + futures::Sink<WebSocketMessage, Error = WebSocketError>
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
            + Send
            + Unpin,
    {
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        async move {
            let (connection_id, handle_io, mut incoming_rx) =

@@ -254,6 +245,11 @@ impl Server {
        worktree_ids
    }

    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
        self.peer.respond(request.receipt(), proto::Ack {}).await?;
        Ok(())
    }

    async fn share_worktree(
        self: Arc<Server>,
        mut request: TypedEnvelope<proto::ShareWorktree>,

@@ -503,7 +499,9 @@ impl Server {
        request: TypedEnvelope<proto::UpdateBuffer>,
    ) -> tide::Result<()> {
        self.broadcast_in_worktree(request.payload.worktree_id, &request)
            .await
            .await?;
        self.peer.respond(request.receipt(), proto::Ack {}).await?;
        Ok(())
    }

    async fn buffer_saved(

@@ -974,8 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
        let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
        task::spawn(async move {
            if let Some(stream) = upgrade_receiver.await {
                let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
                server.handle_connection(stream, addr, user_id).await;
                server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
            }
        });

@@ -1009,17 +1006,25 @@ mod tests {
    };
    use async_std::{sync::RwLockReadGuard, task};
    use gpui::TestAppContext;
    use postage::mpsc;
    use parking_lot::Mutex;
    use postage::{mpsc, watch};
    use serde_json::json;
    use sqlx::types::time::OffsetDateTime;
    use std::{path::Path, sync::Arc, time::Duration};
    use std::{
        path::Path,
        sync::{
            atomic::{AtomicBool, Ordering::SeqCst},
            Arc,
        },
        time::Duration,
    };
    use zed::{
        channel::{Channel, ChannelDetails, ChannelList},
        editor::{Editor, Insert},
        fs::{FakeFs, Fs as _},
        language::LanguageRegistry,
        rpc::Client,
        settings, test,
        rpc::{self, Client},
        settings,
        user::UserStore,
        worktree::Worktree,
    };

@@ -1469,7 +1474,7 @@ mod tests {
            .await;

        // Drop client B's connection and ensure client A observes client B leaving the worktree.
        client_b.disconnect().await.unwrap();
        client_b.disconnect(&cx_b.to_async()).await.unwrap();
        worktree_a
            .condition(&cx_a, |tree, _| tree.peers().len() == 0)
            .await;

@@ -1675,11 +1680,206 @@ mod tests {
        );
    }

    #[gpui::test]
    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
        cx_a.foreground().forbid_parking();

        // Connect to a server as 2 clients.
        let mut server = TestServer::start().await;
        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
        let mut status_b = client_b.status();

        // Create an org that includes these 2 users.
        let db = &server.app_state.db;
        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
        db.add_org_member(org_id, user_id_a, false).await.unwrap();
        db.add_org_member(org_id, user_id_b, false).await.unwrap();

        // Create a channel that includes all the users.
        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
        db.add_channel_member(channel_id, user_id_a, false)
            .await
            .unwrap();
        db.add_channel_member(channel_id, user_id_b, false)
            .await
            .unwrap();
        db.create_channel_message(
            channel_id,
            user_id_b,
            "hello A, it's B.",
            OffsetDateTime::now_utc(),
        )
        .await
        .unwrap();

        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
        channels_a
            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
            .await;

        channels_a.read_with(&cx_a, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
        channels_b
            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
            .await;
        channels_b.read_with(&cx_b, |list, _| {
            assert_eq!(
                list.available_channels().unwrap(),
                &[ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });

        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
            this.get_channel(channel_id.to_proto(), cx).unwrap()
        });
        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
            })
            .await;

        // Disconnect client B, ensuring we can still access its cached channel data.
        server.forbid_connections();
        server.disconnect_client(user_id_b);
        while !matches!(
            status_b.recv().await,
            Some(rpc::Status::ReconnectionError { .. })
        ) {}

        channels_b.read_with(&cx_b, |channels, _| {
            assert_eq!(
                channels.available_channels().unwrap(),
                [ChannelDetails {
                    id: channel_id.to_proto(),
                    name: "test-channel".to_string()
                }]
            )
        });
        channel_b.read_with(&cx_b, |channel, _| {
            assert_eq!(
                channel_messages(channel),
                [("user_b".to_string(), "hello A, it's B.".to_string())]
            )
        });

        // Send a message from client A while B is disconnected.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel
                    .send_message("oh, hi B.".to_string(), cx)
                    .unwrap()
                    .detach();
                let task = channel.send_message("sup".to_string(), cx).unwrap();
                assert_eq!(
                    channel
                        .pending_messages()
                        .iter()
                        .map(|m| &m.body)
                        .collect::<Vec<_>>(),
                    &["oh, hi B.", "sup"]
                );
                task
            })
            .await
            .unwrap();

        // Give client B a chance to reconnect.
        server.allow_connections();
        cx_b.foreground().advance_clock(Duration::from_secs(10));

        // Verify that B sees the new messages upon reconnection.
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string()),
                        ("user_a".to_string(), "oh, hi B.".to_string()),
                        ("user_a".to_string(), "sup".to_string()),
                    ]
            })
            .await;

        // Ensure client A and B can communicate normally after reconnection.
        channel_a
            .update(&mut cx_a, |channel, cx| {
                channel.send_message("you online?".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_b
            .condition(&cx_b, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string()),
                        ("user_a".to_string(), "oh, hi B.".to_string()),
                        ("user_a".to_string(), "sup".to_string()),
                        ("user_a".to_string(), "you online?".to_string()),
                    ]
            })
            .await;

        channel_b
            .update(&mut cx_b, |channel, cx| {
                channel.send_message("yep".to_string(), cx).unwrap()
            })
            .await
            .unwrap();
        channel_a
            .condition(&cx_a, |channel, _| {
                channel_messages(channel)
                    == [
                        ("user_b".to_string(), "hello A, it's B.".to_string()),
                        ("user_a".to_string(), "oh, hi B.".to_string()),
                        ("user_a".to_string(), "sup".to_string()),
                        ("user_a".to_string(), "you online?".to_string()),
                        ("user_b".to_string(), "yep".to_string()),
                    ]
            })
            .await;

        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
            channel
                .messages()
                .cursor::<(), ()>()
                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
                .collect()
        }
    }

    struct TestServer {
        peer: Arc<Peer>,
        app_state: Arc<AppState>,
        server: Arc<Server>,
        notifications: mpsc::Receiver<()>,
        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
        forbid_connections: Arc<AtomicBool>,
        _test_db: TestDb,
    }

@@ -1695,6 +1895,8 @@ mod tests {
            app_state,
            server,
            notifications: notifications.1,
            connection_killers: Default::default(),
            forbid_connections: Default::default(),
            _test_db: test_db,
        }
    }

@@ -1704,20 +1906,67 @@ mod tests {
        cx: &mut TestAppContext,
        name: &str,
    ) -> (UserId, Arc<Client>) {
        let user_id = self.app_state.db.create_user(name, false).await.unwrap();
        let client = Client::new();
        let (client_conn, server_conn) = test::Channel::bidirectional();
        let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
        let client_name = name.to_string();
        let mut client = Client::new();
        let server = self.server.clone();
        let connection_killers = self.connection_killers.clone();
        let forbid_connections = self.forbid_connections.clone();
        Arc::get_mut(&mut client)
            .unwrap()
            .set_login_and_connect_callbacks(
                move |cx| {
                    cx.spawn(|_| async move {
                        let access_token = "the-token".to_string();
                        Ok((client_user_id.0 as u64, access_token))
                    })
                },
                move |user_id, access_token, cx| {
                    assert_eq!(user_id, client_user_id.0 as u64);
                    assert_eq!(access_token, "the-token");

                    let server = server.clone();
                    let connection_killers = connection_killers.clone();
                    let forbid_connections = forbid_connections.clone();
                    let client_name = client_name.clone();
                    cx.spawn(move |cx| async move {
                        if forbid_connections.load(SeqCst) {
                            Err(anyhow!("server is forbidding connections"))
                        } else {
                            let (client_conn, server_conn, kill_conn) = Conn::in_memory();
                            connection_killers.lock().insert(client_user_id, kill_conn);
                            cx.background()
                                .spawn(
                                    self.server
                                        .handle_connection(server_conn, name.to_string(), user_id),
                                )
                                .spawn(server.handle_connection(
                                    server_conn,
                                    client_name,
                                    client_user_id,
                                ))
                                .detach();
                            Ok(client_conn)
                        }
                    })
                },
            );

        client
            .add_connection(user_id.to_proto(), client_conn, &cx.to_async())
            .authenticate_and_connect(&cx.to_async())
            .await
            .unwrap();
        (user_id, client)
        (client_user_id, client)
    }

    fn disconnect_client(&self, user_id: UserId) {
        if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
            let _ = kill_conn.try_send(Some(()));
        }
    }

    fn forbid_connections(&self) {
        self.forbid_connections.store(true, SeqCst);
    }

    fn allow_connections(&self) {
        self.forbid_connections.store(false, SeqCst);
    }

    async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

@@ -11,6 +11,7 @@ use gpui::{
use postage::prelude::Stream;
use std::{
    collections::{HashMap, HashSet},
    mem,
    ops::Range,
    sync::Arc,
};

@@ -71,7 +72,7 @@ pub enum ChannelListEvent {}

#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
    MessagesAdded {
    MessagesUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },

@@ -87,36 +88,47 @@ impl ChannelList {
        rpc: Arc<rpc::Client>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let _task = cx.spawn(|this, mut cx| {
        let _task = cx.spawn_weak(|this, mut cx| {
            let rpc = rpc.clone();
            async move {
                let mut user_id = rpc.user_id();
                loop {
                    let available_channels = if user_id.recv().await.unwrap().is_some() {
                        Some(
                            rpc.request(proto::GetChannels {})
                let mut status = rpc.status();
                while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
                    match status {
                        rpc::Status::Connected { .. } => {
                            let response = rpc
                                .request(proto::GetChannels {})
                                .await
                                .context("failed to fetch available channels")?
                                .channels
                                .into_iter()
                                .map(Into::into)
                                .collect(),
                        )
                    } else {
                        None
                    };

                                .context("failed to fetch available channels")?;
                            this.update(&mut cx, |this, cx| {
                    if available_channels.is_none() {
                                if this.available_channels.is_none() {
                                    return;
                                this.available_channels =
                                    Some(response.channels.into_iter().map(Into::into).collect());

                                let mut to_remove = Vec::new();
                                for (channel_id, channel) in &this.channels {
                                    if let Some(channel) = channel.upgrade(cx) {
                                        channel.update(cx, |channel, cx| channel.rejoin(cx))
                                    } else {
                                        to_remove.push(*channel_id);
                                    }
                        this.channels.clear();
                    }
                    this.available_channels = available_channels;

                                for channel_id in to_remove {
                                    this.channels.remove(&channel_id);
                                }
                                cx.notify();
                            });
                        }
                        rpc::Status::Disconnected { .. } => {
                            this.update(&mut cx, |this, cx| {
                                this.available_channels = None;
                                this.channels.clear();
                                cx.notify();
                            });
                        }
                        _ => {}
                    }
                }
                Ok(())
            }
            .log_err()
        });

@@ -285,6 +297,43 @@ impl Channel {
        false
    }

    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        let rpc = self.rpc.clone();
        let channel_id = self.details.id;
        cx.spawn(|channel, mut cx| {
            async move {
                let response = rpc.request(proto::JoinChannel { channel_id }).await?;
                let messages = messages_from_proto(response.messages, &user_store).await?;
                let loaded_all_messages = response.done;

                channel.update(&mut cx, |channel, cx| {
                    if let Some((first_new_message, last_old_message)) =
                        messages.first().zip(channel.messages.last())
                    {
                        if first_new_message.id > last_old_message.id {
                            let old_messages = mem::take(&mut channel.messages);
                            cx.emit(ChannelEvent::MessagesUpdated {
                                old_range: 0..old_messages.summary().count,
                                new_count: 0,
                            });
                            channel.loaded_all_messages = loaded_all_messages;
                        }
                    }

                    channel.insert_messages(messages, cx);
                    if loaded_all_messages {
                        channel.loaded_all_messages = loaded_all_messages;
                    }
                });

                Ok(())
            }
            .log_err()
        })
        .detach();
    }

    pub fn message_count(&self) -> usize {
        self.messages.summary().count
    }

@@ -350,7 +399,7 @@ impl Channel {
        drop(old_cursor);
        self.messages = new_messages;

        cx.emit(ChannelEvent::MessagesAdded {
        cx.emit(ChannelEvent::MessagesUpdated {
            old_range: start_ix..end_ix,
            new_count,
        });

@@ -446,22 +495,21 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::TestAppContext;
    use postage::mpsc::Receiver;
    use zrpc::{test::Channel, ConnectionId, Peer, Receipt};

    #[gpui::test]
    async fn test_channel_messages(mut cx: TestAppContext) {
        let user_id = 5;
        let client = Client::new();
        let mut server = FakeServer::for_client(user_id, &client, &cx).await;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let user_store = Arc::new(UserStore::new(client.clone()));

        let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));

        // Get the available channels.
        let get_channels = server.receive::<proto::GetChannels>().await;
        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
        server
            .respond(
                get_channels.receipt(),

@@ -492,7 +540,7 @@ mod tests {
            })
            .unwrap();
        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
        let join_channel = server.receive::<proto::JoinChannel>().await;
        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
        server
            .respond(
                join_channel.receipt(),

@@ -517,7 +565,7 @@ mod tests {
            .await;

        // Client requests all users for the received messages
        let mut get_users = server.receive::<proto::GetUsers>().await;
        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
        get_users.payload.user_ids.sort();
        assert_eq!(get_users.payload.user_ids, vec![5, 6]);
        server

@@ -542,7 +590,7 @@ mod tests {

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesAdded {
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }

@@ -574,7 +622,7 @@ mod tests {
            .await;

        // Client requests user for message since they haven't seen them yet
        let get_users = server.receive::<proto::GetUsers>().await;
        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
        assert_eq!(get_users.payload.user_ids, vec![7]);
        server
            .respond(

@@ -591,7 +639,7 @@ mod tests {

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesAdded {
            ChannelEvent::MessagesUpdated {
                old_range: 2..2,
                new_count: 1,
            }

@@ -610,7 +658,7 @@ mod tests {
        channel.update(&mut cx, |channel, cx| {
            assert!(channel.load_more_messages(cx));
        });
        let get_messages = server.receive::<proto::GetChannelMessages>().await;
        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
        assert_eq!(get_messages.payload.channel_id, 5);
        assert_eq!(get_messages.payload.before_message_id, 10);
        server

@@ -638,7 +686,7 @@ mod tests {

        assert_eq!(
            channel.next_event(&cx).await,
            ChannelEvent::MessagesAdded {
            ChannelEvent::MessagesUpdated {
                old_range: 0..0,
                new_count: 2,
            }

@@ -656,53 +704,4 @@ mod tests {
            );
        });
    }

    struct FakeServer {
        peer: Arc<Peer>,
        incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
        connection_id: ConnectionId,
    }

    impl FakeServer {
        async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
            let (client_conn, server_conn) = Channel::bidirectional();
            let peer = Peer::new();
            let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
            cx.background().spawn(io).detach();

            client
                .add_connection(user_id, client_conn, &cx.to_async())
                .await
                .unwrap();

            Self {
                peer,
                incoming,
                connection_id,
            }
        }

        async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
            self.peer.send(self.connection_id, message).await.unwrap();
        }

        async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
            *self
                .incoming
                .recv()
                .await
                .unwrap()
                .into_any()
                .downcast::<TypedEnvelope<M>>()
                .unwrap()
        }

        async fn respond<T: proto::RequestMessage>(
            &self,
            receipt: Receipt<T>,
            response: T::Response,
        ) {
            self.peer.respond(receipt, response).await.unwrap()
        }
    }
}

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
    channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
    editor::Editor,
    rpc::Client,
    rpc::{self, Client},
    theme,
    util::{ResultExt, TryFutureExt},
    Settings,

@@ -14,10 +14,10 @@ use gpui::{
    keymap::Binding,
    platform::CursorStyle,
    views::{ItemType, Select, SelectStyle},
    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
    ViewContext, ViewHandle,
};
use postage::watch;
use postage::{prelude::Stream, watch};
use time::{OffsetDateTime, UtcOffset};

const MESSAGE_LOADING_THRESHOLD: usize = 50;

@@ -31,6 +31,7 @@ pub struct ChatPanel {
    channel_select: ViewHandle<Select>,
    settings: watch::Receiver<Settings>,
    local_timezone: UtcOffset,
    _observe_status: Task<()>,
}

pub enum Event {}

@@ -98,6 +99,14 @@ impl ChatPanel {
                cx.dispatch_action(LoadMoreMessages);
            }
        });
        let _observe_status = cx.spawn(|this, mut cx| {
            let mut status = rpc.status();
            async move {
                while let Some(_) = status.recv().await {
                    this.update(&mut cx, |_, cx| cx.notify());
                }
            }
        });

        let mut this = Self {
            rpc,

@@ -108,6 +117,7 @@ impl ChatPanel {
            channel_select,
            settings,
            local_timezone: cx.platform().local_timezone(),
            _observe_status,
        };

        this.init_active_channel(cx);

@@ -153,6 +163,7 @@ impl ChatPanel {
        if let Some(active_channel) = active_channel {
            self.set_active_channel(active_channel, cx);
        } else {
            self.message_list.reset(0);
            self.active_channel = None;
        }

@@ -183,7 +194,7 @@ impl ChatPanel {
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ChannelEvent::MessagesAdded {
            ChannelEvent::MessagesUpdated {
                old_range,
                new_count,
            } => {

@@ -357,10 +368,6 @@ impl ChatPanel {
            })
        }
    }

    fn is_signed_in(&self) -> bool {
        self.rpc.user_id().borrow().is_some()
    }
}

impl Entity for ChatPanel {

@@ -374,10 +381,9 @@ impl View for ChatPanel {

    fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
        let theme = &self.settings.borrow().theme;
        let element = if self.is_signed_in() {
            self.render_channel()
        } else {
            self.render_sign_in_prompt(cx)
        let element = match *self.rpc.status().borrow() {
            rpc::Status::Connected { .. } => self.render_channel(),
            _ => self.render_sign_in_prompt(cx),
        };
        ConstrainedBox::new(
            Container::new(element)

@@ -389,7 +395,7 @@ impl View for ChatPanel {
    }

    fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
        if self.is_signed_in() {
        if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
            cx.focus(&self.input_editor);
        }
    }

@@ -2695,14 +2695,7 @@ impl<'a> Into<proto::operation::Edit> for &'a EditOperation {
impl<'a> Into<proto::Anchor> for &'a Anchor {
    fn into(self) -> proto::Anchor {
        proto::Anchor {
            version: self
                .version
                .iter()
                .map(|entry| proto::VectorClockEntry {
                    replica_id: entry.replica_id as u32,
                    timestamp: entry.value,
                })
                .collect(),
            version: (&self.version).into(),
            offset: self.offset as u64,
            bias: match self.bias {
                Bias::Left => proto::anchor::Bias::Left as i32,

zed/src/rpc.rs (398 changed lines)

@@ -1,24 +1,24 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use postage::prelude::Stream;
use postage::sink::Sink;
use postage::watch;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Weak;
use std::time::{Duration, Instant};
use std::{convert::TryFrom, future::Future, sync::Arc};
use postage::{prelude::Stream, watch};
use rand::prelude::*;
use std::{
    any::TypeId,
    collections::HashMap,
    convert::TryFrom,
    future::Future,
    sync::{Arc, Weak},
    time::{Duration, Instant},
};
use surf::Url;
use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
    proto::{EnvelopedMessage, RequestMessage},
    Peer, Receipt,
    proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
    Conn, Peer, Receipt,
};

lazy_static! {

@@ -29,25 +29,55 @@ lazy_static! {
pub struct Client {
    peer: Arc<Peer>,
    state: RwLock<ClientState>,
    auth_callback: Option<
        Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
    >,
    connect_callback: Option<
        Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
    >,
}

#[derive(Copy, Clone, Debug)]
pub enum Status {
    Disconnected,
    Authenticating,
    Connecting {
        user_id: u64,
    },
    ConnectionError,
    Connected {
        connection_id: ConnectionId,
        user_id: u64,
    },
    ConnectionLost,
    Reauthenticating,
    Reconnecting {
        user_id: u64,
    },
    ReconnectionError {
        next_reconnection: Instant,
    },
}

struct ClientState {
    connection_id: Option<ConnectionId>,
    user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
    status: (watch::Sender<Status>, watch::Receiver<Status>),
    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
    model_handlers: HashMap<
        (TypeId, u64),
        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
    >,
    _maintain_connection: Option<Task<()>>,
    heartbeat_interval: Duration,
}

impl Default for ClientState {
    fn default() -> Self {
        Self {
            connection_id: Default::default(),
            user_id: watch::channel(),
            status: watch::channel_with(Status::Disconnected),
            entity_id_extractors: Default::default(),
            model_handlers: Default::default(),
            _maintain_connection: None,
            heartbeat_interval: Duration::from_secs(5),
        }
    }
}

@@ -77,11 +107,71 @@ impl Client {
        Arc::new(Self {
            peer: Peer::new(),
            state: Default::default(),
            auth_callback: None,
            connect_callback: None,
        })
    }

    pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
        self.state.read().user_id.1.clone()
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_login_and_connect_callbacks<Login, Connect>(
        &mut self,
        login: Login,
        connect: Connect,
    ) where
        Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
        Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
    {
        self.auth_callback = Some(Box::new(login));
        self.connect_callback = Some(Box::new(connect));
    }

    pub fn status(&self) -> watch::Receiver<Status> {
        self.state.read().status.1.clone()
    }

    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
        let mut state = self.state.write();
        *state.status.0.borrow_mut() = status;

        match status {
            Status::Connected { .. } => {
                let heartbeat_interval = state.heartbeat_interval;
                let this = self.clone();
                let foreground = cx.foreground();
                state._maintain_connection = Some(cx.foreground().spawn(async move {
                    loop {
                        foreground.timer(heartbeat_interval).await;
                        this.request(proto::Ping {}).await.unwrap();
                    }
                }));
            }
            Status::ConnectionLost => {
                let this = self.clone();
                let foreground = cx.foreground();
                let heartbeat_interval = state.heartbeat_interval;
                state._maintain_connection = Some(cx.spawn(|cx| async move {
                    let mut rng = StdRng::from_entropy();
                    let mut delay = Duration::from_millis(100);
                    while let Err(error) = this.authenticate_and_connect(&cx).await {
                        log::error!("failed to connect {}", error);
                        this.set_status(
                            Status::ReconnectionError {
                                next_reconnection: Instant::now() + delay,
                            },
                            &cx,
                        );
                        foreground.timer(delay).await;
                        delay = delay
                            .mul_f32(rng.gen_range(1.0..=2.0))
                            .min(heartbeat_interval);
                    }
                }));
            }
            Status::Disconnected => {
                state._maintain_connection.take();
            }
            _ => {}
        }
    }

    pub fn subscribe_from_model<T, M, F>(

@@ -141,56 +231,57 @@ impl Client {
        self: &Arc<Self>,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()> {
        if self.state.read().connection_id.is_some() {
            return Ok(());
        let was_disconnected = match *self.status().borrow() {
            Status::Disconnected => true,
            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
                false
            }

        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
        let user_id = user_id.parse::<u64>()?;
        let request =
            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));

        if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
            let stream = smol::net::TcpStream::connect(host).await?;
            let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
            let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
                .await
                .context("websocket handshake")?;
            self.add_connection(user_id, stream, cx).await?;
        } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
            let stream = smol::net::TcpStream::connect(host).await?;
            let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
            let (stream, _) = async_tungstenite::client_async(request, stream)
                .await
                .context("websocket handshake")?;
            self.add_connection(user_id, stream, cx).await?;
        } else {
            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
            Status::Connected { .. }
            | Status::Connecting { .. }
            | Status::Reconnecting { .. }
            | Status::Authenticating
            | Status::Reauthenticating => return Ok(()),
        };

        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
        Ok(())
        if was_disconnected {
            self.set_status(Status::Authenticating, cx);
        } else {
            self.set_status(Status::Reauthenticating, cx)
        }

    pub async fn add_connection<Conn>(
        self: &Arc<Self>,
        user_id: u64,
        conn: Conn,
        cx: &AsyncAppContext,
    ) -> anyhow::Result<()>
    where
        Conn: 'static
            + futures::Sink<WebSocketMessage, Error = WebSocketError>
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
            + Unpin
            + Send,
    {
        let (user_id, access_token) = match self.authenticate(&cx).await {
            Ok(result) => result,
            Err(err) => {
                self.set_status(Status::ConnectionError, cx);
                return Err(err);
            }
        };

        if was_disconnected {
            self.set_status(Status::Connecting { user_id }, cx);
        } else {
            self.set_status(Status::Reconnecting { user_id }, cx);
        }
        match self.connect(user_id, &access_token, cx).await {
            Ok(conn) => {
                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
                self.set_connection(user_id, conn, cx).await;
                Ok(())
            }
            Err(err) => {
                self.set_status(Status::ConnectionError, cx);
                Err(err)
            }
        }
    }

    async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
        {
        cx.foreground()
            .spawn({
                let mut cx = cx.clone();
                let this = self.clone();
        cx.foreground()
            .spawn(async move {
                async move {
                    while let Some(message) = incoming.recv().await {
                        let mut state = this.state.write();
                        if let Some(extract_entity_id) =

@@ -215,27 +306,90 @@ impl Client {
                        log::info!("unhandled message {}", message.payload_type_name());
                    }
                }
            }
            })
            .detach();
        }
        cx.background()

        self.set_status(
            Status::Connected {
                connection_id,
                user_id,
            },
            cx,
        );

        let handle_io = cx.background().spawn(handle_io);
        let this = self.clone();
        let cx = cx.clone();
        cx.foreground()
            .spawn(async move {
                if let Err(error) = handle_io.await {
                    log::error!("connection error: {:?}", error);
                match handle_io.await {
                    Ok(()) => this.set_status(Status::Disconnected, &cx),
                    Err(err) => {
                        log::error!("connection error: {:?}", err);
                        this.set_status(Status::ConnectionLost, &cx);
                    }
                }
            })
            .detach();
        let mut state = self.state.write();
        state.connection_id = Some(connection_id);
        state.user_id.0.send(Some(user_id)).await?;
        Ok(())
    }

    pub fn login(
        platform: Arc<dyn gpui::Platform>,
        executor: &Arc<gpui::executor::Background>,
    ) -> Task<Result<(String, String)>> {
        let executor = executor.clone();
    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
        if let Some(callback) = self.auth_callback.as_ref() {
            callback(cx)
        } else {
            self.authenticate_with_browser(cx)
        }
    }

    fn connect(
        self: &Arc<Self>,
        user_id: u64,
        access_token: &str,
        cx: &AsyncAppContext,
    ) -> Task<Result<Conn>> {
        if let Some(callback) = self.connect_callback.as_ref() {
            callback(user_id, access_token, cx)
        } else {
            self.connect_with_websocket(user_id, access_token, cx)
        }
    }

    fn connect_with_websocket(
        self: &Arc<Self>,
        user_id: u64,
        access_token: &str,
        cx: &AsyncAppContext,
    ) -> Task<Result<Conn>> {
        let request =
            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
        cx.background().spawn(async move {
            if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
                let stream = smol::net::TcpStream::connect(host).await?;
                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
                let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
                    .await
                    .context("websocket handshake")?;
                Ok(Conn::new(stream))
            } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
                let stream = smol::net::TcpStream::connect(host).await?;
                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
                let (stream, _) = async_tungstenite::client_async(request, stream)
                    .await
                    .context("websocket handshake")?;
                Ok(Conn::new(stream))
            } else {
                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
            }
        })
    }

    pub fn authenticate_with_browser(
        self: &Arc<Self>,
        cx: &AsyncAppContext,
    ) -> Task<Result<(u64, String)>> {
        let platform = cx.platform();
        let executor = cx.background();
        executor.clone().spawn(async move {
            if let Some((user_id, access_token)) = platform
                .read_credentials(&ZED_SERVER_URL)

@@ -243,7 +397,7 @@ impl Client {
                .flatten()
            {
                log::info!("already signed in. user_id: {}", user_id);
                return Ok((user_id, String::from_utf8(access_token).unwrap()));
                return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
            }

            // Generate a pair of asymmetric encryption keys. The public key will be used by the

@@ -309,21 +463,23 @@ impl Client {
            platform
                .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
                .log_err();
            Ok((user_id.to_string(), access_token))
            Ok((user_id.parse()?, access_token))
        })
    }

    pub async fn disconnect(&self) -> Result<()> {
    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
        let conn_id = self.connection_id()?;
        self.peer.disconnect(conn_id).await;
        self.set_status(Status::Disconnected, cx);
        Ok(())
    }

    fn connection_id(&self) -> Result<ConnectionId> {
        self.state
            .read()
            .connection_id
            .ok_or_else(|| anyhow!("not connected"))
        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
            Ok(connection_id)
        } else {
            Err(anyhow!("not connected"))
        }
    }

    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {

@@ -343,35 +499,6 @@ impl Client {
    }
}

pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
    type Output: 'a + Future<Output = anyhow::Result<()>>;

    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Client,
        cx: &'a mut gpui::AsyncAppContext,
    ) -> Self::Output;
}

impl<'a, M, F, Fut> MessageHandler<'a, M> for F
where
    M: proto::EnvelopedMessage,
    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
    Fut: 'a + Future<Output = anyhow::Result<()>>,
{
    type Output = Fut;

    fn handle(
        &self,
        message: TypedEnvelope<M>,
        rpc: &'a Client,
        cx: &'a mut gpui::AsyncAppContext,
    ) -> Self::Output {
        (self)(message, rpc, cx)
    }
}

const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";

pub fn encode_worktree_url(id: u64, access_token: &str) -> String {

@@ -396,8 +523,56 @@ const LOGIN_RESPONSE: &'static str = "
</html>
";

#[test]
fn test_encode_and_decode_worktree_url() {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        client.disconnect(&cx.to_async()).await.unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.recv().await,
            Some(Status::Connected { .. })
        ));

        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}

        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
    }

    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(

@@ -405,4 +580,5 @@ fn test_encode_and_decode_worktree_url() {
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }
}

zed/src/test.rs (129 changed lines)

@@ -3,24 +3,27 @@ use crate::{
    channel::ChannelList,
    fs::RealFs,
    language::LanguageRegistry,
    rpc,
    rpc::{self, Client},
    settings::{self, ThemeRegistry},
    time::ReplicaId,
    user::UserStore,
    AppState,
};
use gpui::{Entity, ModelHandle, MutableAppContext};
use anyhow::{anyhow, Result};
use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
use postage::{mpsc, prelude::Stream as _};
use smol::channel;
use std::{
    marker::PhantomData,
    path::{Path, PathBuf},
    sync::Arc,
    sync::{
        atomic::{AtomicBool, Ordering::SeqCst},
        Arc,
    },
};
use tempdir::TempDir;

#[cfg(feature = "test-support")]
pub use zrpc::test::Channel;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};

#[cfg(test)]
#[ctor::ctor]

@@ -195,3 +198,117 @@ impl<T: Entity> Observer<T> {
        (observer, notify_rx)
    }
}

pub struct FakeServer {
    peer: Arc<Peer>,
    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
    connection_id: Mutex<Option<ConnectionId>>,
    forbid_connections: AtomicBool,
}

impl FakeServer {
    pub async fn for_client(
        client_user_id: u64,
        client: &mut Arc<Client>,
        cx: &TestAppContext,
    ) -> Arc<Self> {
        let result = Arc::new(Self {
            peer: Peer::new(),
            incoming: Default::default(),
            connection_id: Default::default(),
            forbid_connections: Default::default(),
        });

        Arc::get_mut(client)
            .unwrap()
            .set_login_and_connect_callbacks(
                move |cx| {
                    cx.spawn(|_| async move {
                        let access_token = "the-token".to_string();
                        Ok((client_user_id, access_token))
                    })
                },
                {
                    let server = result.clone();
                    move |user_id, access_token, cx| {
                        assert_eq!(user_id, client_user_id);
                        assert_eq!(access_token, "the-token");
                        cx.spawn({
                            let server = server.clone();
                            move |cx| async move { server.connect(&cx).await }
                        })
                    }
                },
            );

        client
            .authenticate_and_connect(&cx.to_async())
            .await
            .unwrap();
        result
    }

    pub async fn disconnect(&self) {
        self.peer.disconnect(self.connection_id()).await;
        self.connection_id.lock().take();
        self.incoming.lock().take();
    }

    async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
        if self.forbid_connections.load(SeqCst) {
            Err(anyhow!("server is forbidding connections"))
        } else {
            let (client_conn, server_conn, _) = Conn::in_memory();
            let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
            cx.background().spawn(io).detach();
            *self.incoming.lock() = Some(incoming);
            *self.connection_id.lock() = Some(connection_id);
            Ok(client_conn)
        }
    }

    pub fn forbid_connections(&self) {
        self.forbid_connections.store(true, SeqCst);
    }

    pub fn allow_connections(&self) {
        self.forbid_connections.store(false, SeqCst);
    }

    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
        self.peer.send(self.connection_id(), message).await.unwrap();
    }

    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
        let message = self
            .incoming
            .lock()
            .as_mut()
            .expect("not connected")
            .recv()
            .await
            .ok_or_else(|| anyhow!("other half hung up"))?;
        let type_name = message.payload_type_name();
        Ok(*message
            .into_any()
            .downcast::<TypedEnvelope<M>>()
            .unwrap_or_else(|_| {
                panic!(
                    "fake server received unexpected message type: {:?}",
                    type_name
                );
            }))
    }

    pub async fn respond<T: proto::RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) {
        self.peer.respond(receipt, response).await.unwrap()
    }

    fn connection_id(&self) -> ConnectionId {
        self.connection_id.lock().expect("not connected")
    }
}

@@ -234,6 +234,7 @@ impl Worktree {
                .into_iter()
                .map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
                .collect(),
            queued_operations: Default::default(),
            languages,
            _subscriptions,
        })

@@ -656,6 +657,7 @@ pub struct LocalWorktree {
    shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
    peers: HashMap<PeerId, ReplicaId>,
    languages: Arc<LanguageRegistry>,
    queued_operations: Vec<(u64, Operation)>,
    fs: Arc<dyn Fs>,
}

@@ -711,6 +713,7 @@ impl LocalWorktree {
            poll_task: None,
            open_buffers: Default::default(),
            shared_buffers: Default::default(),
            queued_operations: Default::default(),
            peers: Default::default(),
            languages,
            fs,

@@ -1091,6 +1094,7 @@ pub struct RemoteWorktree {
    open_buffers: HashMap<usize, RemoteBuffer>,
    peers: HashMap<PeerId, ReplicaId>,
    languages: Arc<LanguageRegistry>,
    queued_operations: Vec<(u64, Operation)>,
    _subscriptions: Vec<rpc::Subscription>,
}

@@ -1550,16 +1554,23 @@ impl File {
                .map(|share| (share.rpc.clone(), share.remote_id)),
            Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
        } {
            cx.spawn(|_, _| async move {
            cx.spawn(|worktree, mut cx| async move {
                if let Err(error) = rpc
                    .send(proto::UpdateBuffer {
                    .request(proto::UpdateBuffer {
                        worktree_id: remote_id,
                        buffer_id,
                        operations: Some(operation).iter().map(Into::into).collect(),
                        operations: vec![(&operation).into()],
                    })
                    .await
                {
                    worktree.update(&mut cx, |worktree, _| {
                    log::error!("error sending buffer operation: {}", error);
                        match worktree {
                            Worktree::Local(t) => &mut t.queued_operations,
                            Worktree::Remote(t) => &mut t.queued_operations,
                        }
                        .push((buffer_id, operation));
                    });
                }
            })
            .detach();

@@ -1582,7 +1593,7 @@ impl File {
                    .await
                {
                    log::error!("error closing remote buffer: {}", error);
                };
                }
            })
            .detach();
    }

@@ -6,9 +6,9 @@ message Envelope {
    optional uint32 responding_to = 2;
    optional uint32 original_sender_id = 3;
    oneof payload {
        Error error = 4;
        Ping ping = 5;
        Pong pong = 6;
        Ack ack = 4;
        Error error = 5;
        Ping ping = 6;
        ShareWorktree share_worktree = 7;
        ShareWorktreeResponse share_worktree_response = 8;
        OpenWorktree open_worktree = 9;

@@ -40,13 +40,9 @@ message Envelope {

// Messages

message Ping {
    int32 id = 1;
}
message Ping {}

message Pong {
    int32 id = 2;
}
message Ack {}

message Error {
    string message = 1;
101
zrpc/src/conn.rs
Normal file
101
zrpc/src/conn.rs
Normal file
|
@ -0,0 +1,101 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
use std::{io, task::Poll};

pub struct Conn {
    pub(crate) tx:
        Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
    pub(crate) rx: Box<
        dyn 'static
            + Send
            + Unpin
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
    >,
}

impl Conn {
    pub fn new<S>(stream: S) -> Self
    where
        S: 'static
            + Send
            + Unpin
            + futures::Sink<WebSocketMessage, Error = WebSocketError>
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
    {
        let (tx, rx) = stream.split();
        Self {
            tx: Box::new(tx),
            rx: Box::new(rx),
        }
    }

    pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
        self.tx.send(message).await
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
        let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
        postage::stream::Stream::try_recv(&mut kill_rx).unwrap();

        let (a_tx, a_rx) = Self::channel(kill_rx.clone());
        let (b_tx, b_rx) = Self::channel(kill_rx);
        (
            Self { tx: a_tx, rx: b_rx },
            Self { tx: b_tx, rx: a_rx },
            kill_tx,
        )
    }

    #[cfg(any(test, feature = "test-support"))]
    fn channel(
        kill_rx: postage::watch::Receiver<Option<()>>,
    ) -> (
        Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
        Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
    ) {
        use futures::{future, SinkExt as _};
        use io::{Error, ErrorKind};

        let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
        let tx = tx
            .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
            .with({
                let kill_rx = kill_rx.clone();
                move |msg| {
                    if kill_rx.borrow().is_none() {
                        future::ready(Ok(msg))
                    } else {
                        future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
                    }
                }
            });
        let rx = KillableReceiver { kill_rx, rx };

        (Box::new(tx), Box::new(rx))
    }
}

struct KillableReceiver {
    rx: mpsc::UnboundedReceiver<WebSocketMessage>,
    kill_rx: postage::watch::Receiver<Option<()>>,
}

impl Stream for KillableReceiver {
    type Item = Result<WebSocketMessage, WebSocketError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
            Poll::Ready(Some(Err(io::Error::new(
                io::ErrorKind::Other,
                "connection killed",
            )
            .into())))
        } else {
            self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
        }
    }
}

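Since `Conn::in_memory` replaces the old `test::Channel`, a usage sketch may help. The third return value is a kill switch: flipping the shared watch to `Some(())` makes both directions fail with a "connection killed" error. The bindings below are assumptions, and this assumes postage's async `Sink::send`:

    // Sketch: simulate an abrupt disconnect in a test.
    use postage::prelude::Sink as _;

    let (client_conn, server_conn, mut kill_tx) = Conn::in_memory();
    // ...hand `client_conn` and `server_conn` to two peers, exchange messages...
    kill_tx.send(Some(())).await.ok(); // every later send/recv on either end now errors
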
@ -1,7 +1,6 @@
pub mod auth;
mod conn;
mod peer;
pub mod proto;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

pub use conn::Conn;
pub use peer::*;

@ -1,8 +1,8 @@
use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Conn;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{FutureExt, StreamExt};
use futures::FutureExt as _;
use postage::{
    mpsc,
    prelude::{Sink as _, Stream as _},

@ -98,21 +98,14 @@ impl Peer {
        })
    }

    pub async fn add_connection<Conn>(
    pub async fn add_connection(
        self: &Arc<Self>,
        conn: Conn,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
    )
    where
        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
            + Send
            + Unpin,
    {
        let (tx, rx) = conn.split();
    ) {
        let connection_id = ConnectionId(
            self.next_connection_id
                .fetch_add(1, atomic::Ordering::SeqCst),

@ -124,9 +117,10 @@ impl Peer {
            next_message_id: Default::default(),
            response_channels: Default::default(),
        };
        let mut writer = MessageStream::new(tx);
        let mut reader = MessageStream::new(rx);
        let mut writer = MessageStream::new(conn.tx);
        let mut reader = MessageStream::new(conn.rx);

        let this = self.clone();
        let response_channels = connection.response_channels.clone();
        let handle_io = async move {
            loop {

@ -147,6 +141,7 @@ impl Peer {
                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                        if incoming_tx.send(envelope).await.is_err() {
                            response_channels.lock().await.clear();
                            this.connections.write().await.remove(&connection_id);
                            return Ok(())
                        }
                    } else {

@ -158,6 +153,7 @@ impl Peer {
                    }
                    Err(error) => {
                        response_channels.lock().await.clear();
                        this.connections.write().await.remove(&connection_id);
                        Err(error).context("received invalid RPC message")?;
                    }
                },

@ -165,11 +161,13 @@ impl Peer {
                Some(outgoing) => {
                    if let Err(result) = writer.write_message(&outgoing).await {
                        response_channels.lock().await.clear();
                        this.connections.write().await.remove(&connection_id);
                        Err(result).context("failed to write RPC message")?;
                    }
                }
                None => {
                    response_channels.lock().await.clear();
                    this.connections.write().await.remove(&connection_id);
                    return Ok(())
                }
            }

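With the generic `Sink + Stream` bound gone, `add_connection` takes a concrete `Conn` and, as before, returns an io future that must be driven for the connection to make progress, plus a channel of incoming envelopes. A minimal wiring sketch based on the tests below (binding names assumed):

    // Sketch: wire one end of an in-memory connection into a peer.
    use futures::StreamExt as _;

    let peer = Peer::new();
    let (conn, _other_end, _kill_tx) = Conn::in_memory();
    let (connection_id, io_task, mut incoming) = peer.add_connection(conn).await;
    smol::spawn(io_task).detach(); // drives reads/writes; exiting cleans up peer state
    smol::spawn(async move { while let Some(_envelope) = incoming.next().await {} }).detach();
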
@ -342,7 +340,9 @@ impl Peer {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    #[test]
    fn test_request_response() {

@ -352,12 +352,12 @@ mod tests {
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
        let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
        let (client1_conn_id, io_task1, _) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
        let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
        let (client2_conn_id, io_task3, _) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

@ -371,18 +371,18 @@ mod tests {

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping { id: 1 },)
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Pong { id: 1 }
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping { id: 2 },)
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Pong { id: 2 }
            proto::Ack {}
        );

        assert_eq!(

@ -438,13 +438,7 @@ mod tests {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(
                        receipt,
                        proto::Pong {
                            id: envelope.payload.id,
                        },
                    )
                    .await?
                    peer.respond(receipt, proto::Ack {}).await?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {

@ -492,7 +486,7 @@ mod tests {
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();
            let (client_conn, mut server_conn, _) = Conn::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =

@ -516,18 +510,17 @@ mod tests {

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err()
            );
                .is_err());
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            let (client_conn, server_conn, _) = Conn::in_memory();
            drop(server_conn);

            let client = Peer::new();

@ -537,7 +530,7 @@ mod tests {
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping { id: 42 })
                .request(connection_id, proto::Ping {})
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");

@ -120,6 +120,7 @@ macro_rules! entity_messages {
}

messages!(
    Ack,
    AddPeer,
    BufferSaved,
    ChannelMessageSent,

@ -140,7 +141,6 @@ messages!(
    OpenWorktree,
    OpenWorktreeResponse,
    Ping,
    Pong,
    RemovePeer,
    SaveBuffer,
    SendChannelMessage,

@ -157,8 +157,9 @@ request_messages!(
    (JoinChannel, JoinChannelResponse),
    (OpenBuffer, OpenBufferResponse),
    (OpenWorktree, OpenWorktreeResponse),
    (Ping, Pong),
    (Ping, Ack),
    (SaveBuffer, BufferSaved),
    (UpdateBuffer, Ack),
    (ShareWorktree, ShareWorktreeResponse),
    (SendChannelMessage, SendChannelMessageResponse),
    (GetChannelMessages, GetChannelMessagesResponse),

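Each pair above binds a request type to its response type at the type level, which is what lets `Peer::request` return a typed result. After this change the keep-alive round trip reads as follows (sketch; `peer` and `conn_id` are assumed bindings, pattern taken from the updated tests):

    // `(Ping, Ack)` means requesting a Ping yields an Ack once the server responds.
    let ack: proto::Ack = peer.request(conn_id, proto::Ping {}).await?;
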
@ -247,30 +248,3 @@ impl From<SystemTime> for Timestamp {
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let stream = test::Channel::new();
            let message1 = Ping { id: 5 }.into_envelope(3, None, None);
            let message2 = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let mut message_stream = MessageStream::new(stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }
}

@ -1,64 +0,0 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use std::{
    io,
    pin::Pin,
    task::{Context, Poll},
};

pub struct Channel {
    tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
    rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
}

impl Channel {
    pub fn new() -> Self {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        Self { tx, rx }
    }

    pub fn bidirectional() -> (Self, Self) {
        let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
        let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
        let a = Self { tx: a_tx, rx: b_rx };
        let b = Self { tx: b_tx, rx: a_rx };
        (a, b)
    }
}

impl futures::Sink<WebSocketMessage> for Channel {
    type Error = WebSocketError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.tx)
            .poll_ready(cx)
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
    }

    fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
        Pin::new(&mut self.tx)
            .start_send(item)
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.tx)
            .poll_flush(cx)
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.tx)
            .poll_close(cx)
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
    }
}

impl futures::Stream for Channel {
    type Item = Result<WebSocketMessage, WebSocketError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.rx)
            .poll_next(cx)
            .map(|i| i.map(|i| Ok(i)))
    }
}