Merge commit '680b86b17c
' into main
This commit is contained in:
commit
3d4a451c15
14 changed files with 1045 additions and 439 deletions
|
@ -3,8 +3,9 @@ use async_task::Runnable;
|
||||||
pub use async_task::Task;
|
pub use async_task::Task;
|
||||||
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
|
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use postage::{barrier, prelude::Stream as _};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use smol::{channel, prelude::*, Executor};
|
use smol::{channel, prelude::*, Executor, Timer};
|
||||||
use std::{
|
use std::{
|
||||||
fmt::{self, Debug},
|
fmt::{self, Debug},
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
|
@ -18,7 +19,7 @@ use std::{
|
||||||
},
|
},
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
thread,
|
thread,
|
||||||
time::Duration,
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
use waker_fn::waker_fn;
|
use waker_fn::waker_fn;
|
||||||
|
|
||||||
|
@ -49,6 +50,8 @@ struct DeterministicState {
|
||||||
spawned_from_foreground: Vec<(Runnable, Backtrace)>,
|
spawned_from_foreground: Vec<(Runnable, Backtrace)>,
|
||||||
forbid_parking: bool,
|
forbid_parking: bool,
|
||||||
block_on_ticks: RangeInclusive<usize>,
|
block_on_ticks: RangeInclusive<usize>,
|
||||||
|
now: Instant,
|
||||||
|
pending_timers: Vec<(Instant, barrier::Sender)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Deterministic {
|
pub struct Deterministic {
|
||||||
|
@ -67,6 +70,8 @@ impl Deterministic {
|
||||||
spawned_from_foreground: Default::default(),
|
spawned_from_foreground: Default::default(),
|
||||||
forbid_parking: false,
|
forbid_parking: false,
|
||||||
block_on_ticks: 0..=1000,
|
block_on_ticks: 0..=1000,
|
||||||
|
now: Instant::now(),
|
||||||
|
pending_timers: Default::default(),
|
||||||
})),
|
})),
|
||||||
parker: Default::default(),
|
parker: Default::default(),
|
||||||
}
|
}
|
||||||
|
@ -119,17 +124,39 @@ impl Deterministic {
|
||||||
T: 'static,
|
T: 'static,
|
||||||
F: Future<Output = T> + 'static,
|
F: Future<Output = T> + 'static,
|
||||||
{
|
{
|
||||||
smol::pin!(future);
|
|
||||||
|
|
||||||
let unparker = self.parker.lock().unparker();
|
|
||||||
let woken = Arc::new(AtomicBool::new(false));
|
let woken = Arc::new(AtomicBool::new(false));
|
||||||
let waker = {
|
let mut future = Box::pin(future);
|
||||||
let woken = woken.clone();
|
loop {
|
||||||
waker_fn(move || {
|
if let Some(result) = self.run_internal(woken.clone(), &mut future) {
|
||||||
woken.store(true, SeqCst);
|
return result;
|
||||||
unparker.unpark();
|
}
|
||||||
})
|
|
||||||
};
|
if !woken.load(SeqCst) && self.state.lock().forbid_parking {
|
||||||
|
panic!("deterministic executor parked after a call to forbid_parking");
|
||||||
|
}
|
||||||
|
|
||||||
|
woken.store(false, SeqCst);
|
||||||
|
self.parker.lock().park();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_until_parked(&self) {
|
||||||
|
let woken = Arc::new(AtomicBool::new(false));
|
||||||
|
let future = std::future::pending::<()>();
|
||||||
|
smol::pin!(future);
|
||||||
|
self.run_internal(woken, future);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_internal<F, T>(&self, woken: Arc<AtomicBool>, mut future: F) -> Option<T>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
F: Future<Output = T> + Unpin,
|
||||||
|
{
|
||||||
|
let unparker = self.parker.lock().unparker();
|
||||||
|
let waker = waker_fn(move || {
|
||||||
|
woken.store(true, SeqCst);
|
||||||
|
unparker.unpark();
|
||||||
|
});
|
||||||
|
|
||||||
let mut cx = Context::from_waker(&waker);
|
let mut cx = Context::from_waker(&waker);
|
||||||
let mut trace = Trace::default();
|
let mut trace = Trace::default();
|
||||||
|
@ -163,23 +190,17 @@ impl Deterministic {
|
||||||
runnable.run();
|
runnable.run();
|
||||||
} else {
|
} else {
|
||||||
drop(state);
|
drop(state);
|
||||||
if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
|
if let Poll::Ready(result) = future.poll(&mut cx) {
|
||||||
return result;
|
return Some(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
let state = self.state.lock();
|
let state = self.state.lock();
|
||||||
if state.scheduled_from_foreground.is_empty()
|
if state.scheduled_from_foreground.is_empty()
|
||||||
&& state.scheduled_from_background.is_empty()
|
&& state.scheduled_from_background.is_empty()
|
||||||
&& state.spawned_from_foreground.is_empty()
|
&& state.spawned_from_foreground.is_empty()
|
||||||
{
|
{
|
||||||
if state.forbid_parking && !woken.load(SeqCst) {
|
return None;
|
||||||
panic!("deterministic executor parked after a call to forbid_parking");
|
|
||||||
}
|
|
||||||
drop(state);
|
|
||||||
woken.store(false, SeqCst);
|
|
||||||
self.parker.lock().park();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -407,6 +428,41 @@ impl Foreground {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn timer(&self, duration: Duration) {
|
||||||
|
match self {
|
||||||
|
Self::Deterministic(executor) => {
|
||||||
|
let (tx, mut rx) = barrier::channel();
|
||||||
|
{
|
||||||
|
let mut state = executor.state.lock();
|
||||||
|
let wakeup_at = state.now + duration;
|
||||||
|
state.pending_timers.push((wakeup_at, tx));
|
||||||
|
}
|
||||||
|
rx.recv().await;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
Timer::after(duration).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn advance_clock(&self, duration: Duration) {
|
||||||
|
match self {
|
||||||
|
Self::Deterministic(executor) => {
|
||||||
|
executor.run_until_parked();
|
||||||
|
|
||||||
|
let mut state = executor.state.lock();
|
||||||
|
state.now += duration;
|
||||||
|
let now = state.now;
|
||||||
|
let mut pending_timers = mem::take(&mut state.pending_timers);
|
||||||
|
drop(state);
|
||||||
|
|
||||||
|
pending_timers.retain(|(wakeup, _)| *wakeup > now);
|
||||||
|
executor.state.lock().pending_timers.extend(pending_timers);
|
||||||
|
}
|
||||||
|
_ => panic!("this method can only be called on a deterministic executor"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
|
pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
|
||||||
match self {
|
match self {
|
||||||
Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
|
Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
|
||||||
|
|
|
@ -5,10 +5,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use async_std::{sync::RwLock, task};
|
use async_std::{sync::RwLock, task};
|
||||||
use async_tungstenite::{
|
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
|
||||||
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
|
|
||||||
WebSocketStream,
|
|
||||||
};
|
|
||||||
use futures::{future::BoxFuture, FutureExt};
|
use futures::{future::BoxFuture, FutureExt};
|
||||||
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
|
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
|
||||||
use sha1::{Digest as _, Sha1};
|
use sha1::{Digest as _, Sha1};
|
||||||
|
@ -30,7 +27,7 @@ use time::OffsetDateTime;
|
||||||
use zrpc::{
|
use zrpc::{
|
||||||
auth::random_token,
|
auth::random_token,
|
||||||
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
|
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
|
||||||
ConnectionId, Peer, TypedEnvelope,
|
Conn, ConnectionId, Peer, TypedEnvelope,
|
||||||
};
|
};
|
||||||
|
|
||||||
type ReplicaId = u16;
|
type ReplicaId = u16;
|
||||||
|
@ -95,6 +92,7 @@ impl Server {
|
||||||
};
|
};
|
||||||
|
|
||||||
server
|
server
|
||||||
|
.add_handler(Server::ping)
|
||||||
.add_handler(Server::share_worktree)
|
.add_handler(Server::share_worktree)
|
||||||
.add_handler(Server::join_worktree)
|
.add_handler(Server::join_worktree)
|
||||||
.add_handler(Server::update_worktree)
|
.add_handler(Server::update_worktree)
|
||||||
|
@ -133,19 +131,12 @@ impl Server {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn handle_connection<Conn>(
|
pub fn handle_connection(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
connection: Conn,
|
connection: Conn,
|
||||||
addr: String,
|
addr: String,
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
) -> impl Future<Output = ()>
|
) -> impl Future<Output = ()> {
|
||||||
where
|
|
||||||
Conn: 'static
|
|
||||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
||||||
+ Send
|
|
||||||
+ Unpin,
|
|
||||||
{
|
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
async move {
|
async move {
|
||||||
let (connection_id, handle_io, mut incoming_rx) =
|
let (connection_id, handle_io, mut incoming_rx) =
|
||||||
|
@ -254,6 +245,11 @@ impl Server {
|
||||||
worktree_ids
|
worktree_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
|
||||||
|
self.peer.respond(request.receipt(), proto::Ack {}).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn share_worktree(
|
async fn share_worktree(
|
||||||
self: Arc<Server>,
|
self: Arc<Server>,
|
||||||
mut request: TypedEnvelope<proto::ShareWorktree>,
|
mut request: TypedEnvelope<proto::ShareWorktree>,
|
||||||
|
@ -503,7 +499,9 @@ impl Server {
|
||||||
request: TypedEnvelope<proto::UpdateBuffer>,
|
request: TypedEnvelope<proto::UpdateBuffer>,
|
||||||
) -> tide::Result<()> {
|
) -> tide::Result<()> {
|
||||||
self.broadcast_in_worktree(request.payload.worktree_id, &request)
|
self.broadcast_in_worktree(request.payload.worktree_id, &request)
|
||||||
.await
|
.await?;
|
||||||
|
self.peer.respond(request.receipt(), proto::Ack {}).await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn buffer_saved(
|
async fn buffer_saved(
|
||||||
|
@ -974,8 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
|
||||||
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
|
||||||
task::spawn(async move {
|
task::spawn(async move {
|
||||||
if let Some(stream) = upgrade_receiver.await {
|
if let Some(stream) = upgrade_receiver.await {
|
||||||
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
|
server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
|
||||||
server.handle_connection(stream, addr, user_id).await;
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1009,17 +1006,25 @@ mod tests {
|
||||||
};
|
};
|
||||||
use async_std::{sync::RwLockReadGuard, task};
|
use async_std::{sync::RwLockReadGuard, task};
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
use postage::mpsc;
|
use parking_lot::Mutex;
|
||||||
|
use postage::{mpsc, watch};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sqlx::types::time::OffsetDateTime;
|
use sqlx::types::time::OffsetDateTime;
|
||||||
use std::{path::Path, sync::Arc, time::Duration};
|
use std::{
|
||||||
|
path::Path,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
use zed::{
|
use zed::{
|
||||||
channel::{Channel, ChannelDetails, ChannelList},
|
channel::{Channel, ChannelDetails, ChannelList},
|
||||||
editor::{Editor, Insert},
|
editor::{Editor, Insert},
|
||||||
fs::{FakeFs, Fs as _},
|
fs::{FakeFs, Fs as _},
|
||||||
language::LanguageRegistry,
|
language::LanguageRegistry,
|
||||||
rpc::Client,
|
rpc::{self, Client},
|
||||||
settings, test,
|
settings,
|
||||||
user::UserStore,
|
user::UserStore,
|
||||||
worktree::Worktree,
|
worktree::Worktree,
|
||||||
};
|
};
|
||||||
|
@ -1469,7 +1474,7 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Drop client B's connection and ensure client A observes client B leaving the worktree.
|
// Drop client B's connection and ensure client A observes client B leaving the worktree.
|
||||||
client_b.disconnect().await.unwrap();
|
client_b.disconnect(&cx_b.to_async()).await.unwrap();
|
||||||
worktree_a
|
worktree_a
|
||||||
.condition(&cx_a, |tree, _| tree.peers().len() == 0)
|
.condition(&cx_a, |tree, _| tree.peers().len() == 0)
|
||||||
.await;
|
.await;
|
||||||
|
@ -1675,11 +1680,206 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
|
||||||
|
cx_a.foreground().forbid_parking();
|
||||||
|
|
||||||
|
// Connect to a server as 2 clients.
|
||||||
|
let mut server = TestServer::start().await;
|
||||||
|
let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
|
||||||
|
let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
|
||||||
|
let mut status_b = client_b.status();
|
||||||
|
|
||||||
|
// Create an org that includes these 2 users.
|
||||||
|
let db = &server.app_state.db;
|
||||||
|
let org_id = db.create_org("Test Org", "test-org").await.unwrap();
|
||||||
|
db.add_org_member(org_id, user_id_a, false).await.unwrap();
|
||||||
|
db.add_org_member(org_id, user_id_b, false).await.unwrap();
|
||||||
|
|
||||||
|
// Create a channel that includes all the users.
|
||||||
|
let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
|
||||||
|
db.add_channel_member(channel_id, user_id_a, false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
db.add_channel_member(channel_id, user_id_b, false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
db.create_channel_message(
|
||||||
|
channel_id,
|
||||||
|
user_id_b,
|
||||||
|
"hello A, it's B.",
|
||||||
|
OffsetDateTime::now_utc(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let user_store_a = Arc::new(UserStore::new(client_a.clone()));
|
||||||
|
let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
|
||||||
|
channels_a
|
||||||
|
.condition(&mut cx_a, |list, _| list.available_channels().is_some())
|
||||||
|
.await;
|
||||||
|
|
||||||
|
channels_a.read_with(&cx_a, |list, _| {
|
||||||
|
assert_eq!(
|
||||||
|
list.available_channels().unwrap(),
|
||||||
|
&[ChannelDetails {
|
||||||
|
id: channel_id.to_proto(),
|
||||||
|
name: "test-channel".to_string()
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
});
|
||||||
|
let channel_a = channels_a.update(&mut cx_a, |this, cx| {
|
||||||
|
this.get_channel(channel_id.to_proto(), cx).unwrap()
|
||||||
|
});
|
||||||
|
channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
|
||||||
|
channel_a
|
||||||
|
.condition(&cx_a, |channel, _| {
|
||||||
|
channel_messages(channel)
|
||||||
|
== [("user_b".to_string(), "hello A, it's B.".to_string())]
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let user_store_b = Arc::new(UserStore::new(client_b.clone()));
|
||||||
|
let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
|
||||||
|
channels_b
|
||||||
|
.condition(&mut cx_b, |list, _| list.available_channels().is_some())
|
||||||
|
.await;
|
||||||
|
channels_b.read_with(&cx_b, |list, _| {
|
||||||
|
assert_eq!(
|
||||||
|
list.available_channels().unwrap(),
|
||||||
|
&[ChannelDetails {
|
||||||
|
id: channel_id.to_proto(),
|
||||||
|
name: "test-channel".to_string()
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let channel_b = channels_b.update(&mut cx_b, |this, cx| {
|
||||||
|
this.get_channel(channel_id.to_proto(), cx).unwrap()
|
||||||
|
});
|
||||||
|
channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
|
||||||
|
channel_b
|
||||||
|
.condition(&cx_b, |channel, _| {
|
||||||
|
channel_messages(channel)
|
||||||
|
== [("user_b".to_string(), "hello A, it's B.".to_string())]
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Disconnect client B, ensuring we can still access its cached channel data.
|
||||||
|
server.forbid_connections();
|
||||||
|
server.disconnect_client(user_id_b);
|
||||||
|
while !matches!(
|
||||||
|
status_b.recv().await,
|
||||||
|
Some(rpc::Status::ReconnectionError { .. })
|
||||||
|
) {}
|
||||||
|
|
||||||
|
channels_b.read_with(&cx_b, |channels, _| {
|
||||||
|
assert_eq!(
|
||||||
|
channels.available_channels().unwrap(),
|
||||||
|
[ChannelDetails {
|
||||||
|
id: channel_id.to_proto(),
|
||||||
|
name: "test-channel".to_string()
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
});
|
||||||
|
channel_b.read_with(&cx_b, |channel, _| {
|
||||||
|
assert_eq!(
|
||||||
|
channel_messages(channel),
|
||||||
|
[("user_b".to_string(), "hello A, it's B.".to_string())]
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send a message from client A while B is disconnected.
|
||||||
|
channel_a
|
||||||
|
.update(&mut cx_a, |channel, cx| {
|
||||||
|
channel
|
||||||
|
.send_message("oh, hi B.".to_string(), cx)
|
||||||
|
.unwrap()
|
||||||
|
.detach();
|
||||||
|
let task = channel.send_message("sup".to_string(), cx).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
channel
|
||||||
|
.pending_messages()
|
||||||
|
.iter()
|
||||||
|
.map(|m| &m.body)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
&["oh, hi B.", "sup"]
|
||||||
|
);
|
||||||
|
task
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Give client B a chance to reconnect.
|
||||||
|
server.allow_connections();
|
||||||
|
cx_b.foreground().advance_clock(Duration::from_secs(10));
|
||||||
|
|
||||||
|
// Verify that B sees the new messages upon reconnection.
|
||||||
|
channel_b
|
||||||
|
.condition(&cx_b, |channel, _| {
|
||||||
|
channel_messages(channel)
|
||||||
|
== [
|
||||||
|
("user_b".to_string(), "hello A, it's B.".to_string()),
|
||||||
|
("user_a".to_string(), "oh, hi B.".to_string()),
|
||||||
|
("user_a".to_string(), "sup".to_string()),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Ensure client A and B can communicate normally after reconnection.
|
||||||
|
channel_a
|
||||||
|
.update(&mut cx_a, |channel, cx| {
|
||||||
|
channel.send_message("you online?".to_string(), cx).unwrap()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
channel_b
|
||||||
|
.condition(&cx_b, |channel, _| {
|
||||||
|
channel_messages(channel)
|
||||||
|
== [
|
||||||
|
("user_b".to_string(), "hello A, it's B.".to_string()),
|
||||||
|
("user_a".to_string(), "oh, hi B.".to_string()),
|
||||||
|
("user_a".to_string(), "sup".to_string()),
|
||||||
|
("user_a".to_string(), "you online?".to_string()),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
channel_b
|
||||||
|
.update(&mut cx_b, |channel, cx| {
|
||||||
|
channel.send_message("yep".to_string(), cx).unwrap()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
channel_a
|
||||||
|
.condition(&cx_a, |channel, _| {
|
||||||
|
channel_messages(channel)
|
||||||
|
== [
|
||||||
|
("user_b".to_string(), "hello A, it's B.".to_string()),
|
||||||
|
("user_a".to_string(), "oh, hi B.".to_string()),
|
||||||
|
("user_a".to_string(), "sup".to_string()),
|
||||||
|
("user_a".to_string(), "you online?".to_string()),
|
||||||
|
("user_b".to_string(), "yep".to_string()),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
|
||||||
|
channel
|
||||||
|
.messages()
|
||||||
|
.cursor::<(), ()>()
|
||||||
|
.map(|m| (m.sender.github_login.clone(), m.body.clone()))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct TestServer {
|
struct TestServer {
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
server: Arc<Server>,
|
server: Arc<Server>,
|
||||||
notifications: mpsc::Receiver<()>,
|
notifications: mpsc::Receiver<()>,
|
||||||
|
connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
|
||||||
|
forbid_connections: Arc<AtomicBool>,
|
||||||
_test_db: TestDb,
|
_test_db: TestDb,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1695,6 +1895,8 @@ mod tests {
|
||||||
app_state,
|
app_state,
|
||||||
server,
|
server,
|
||||||
notifications: notifications.1,
|
notifications: notifications.1,
|
||||||
|
connection_killers: Default::default(),
|
||||||
|
forbid_connections: Default::default(),
|
||||||
_test_db: test_db,
|
_test_db: test_db,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1704,20 +1906,67 @@ mod tests {
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
name: &str,
|
name: &str,
|
||||||
) -> (UserId, Arc<Client>) {
|
) -> (UserId, Arc<Client>) {
|
||||||
let user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
|
||||||
let client = Client::new();
|
let client_name = name.to_string();
|
||||||
let (client_conn, server_conn) = test::Channel::bidirectional();
|
let mut client = Client::new();
|
||||||
cx.background()
|
let server = self.server.clone();
|
||||||
.spawn(
|
let connection_killers = self.connection_killers.clone();
|
||||||
self.server
|
let forbid_connections = self.forbid_connections.clone();
|
||||||
.handle_connection(server_conn, name.to_string(), user_id),
|
Arc::get_mut(&mut client)
|
||||||
)
|
.unwrap()
|
||||||
.detach();
|
.set_login_and_connect_callbacks(
|
||||||
|
move |cx| {
|
||||||
|
cx.spawn(|_| async move {
|
||||||
|
let access_token = "the-token".to_string();
|
||||||
|
Ok((client_user_id.0 as u64, access_token))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
move |user_id, access_token, cx| {
|
||||||
|
assert_eq!(user_id, client_user_id.0 as u64);
|
||||||
|
assert_eq!(access_token, "the-token");
|
||||||
|
|
||||||
|
let server = server.clone();
|
||||||
|
let connection_killers = connection_killers.clone();
|
||||||
|
let forbid_connections = forbid_connections.clone();
|
||||||
|
let client_name = client_name.clone();
|
||||||
|
cx.spawn(move |cx| async move {
|
||||||
|
if forbid_connections.load(SeqCst) {
|
||||||
|
Err(anyhow!("server is forbidding connections"))
|
||||||
|
} else {
|
||||||
|
let (client_conn, server_conn, kill_conn) = Conn::in_memory();
|
||||||
|
connection_killers.lock().insert(client_user_id, kill_conn);
|
||||||
|
cx.background()
|
||||||
|
.spawn(server.handle_connection(
|
||||||
|
server_conn,
|
||||||
|
client_name,
|
||||||
|
client_user_id,
|
||||||
|
))
|
||||||
|
.detach();
|
||||||
|
Ok(client_conn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
client
|
client
|
||||||
.add_connection(user_id.to_proto(), client_conn, &cx.to_async())
|
.authenticate_and_connect(&cx.to_async())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
(user_id, client)
|
(client_user_id, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn disconnect_client(&self, user_id: UserId) {
|
||||||
|
if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
|
||||||
|
let _ = kill_conn.try_send(Some(()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forbid_connections(&self) {
|
||||||
|
self.forbid_connections.store(true, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allow_connections(&self) {
|
||||||
|
self.forbid_connections.store(false, SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
|
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
|
||||||
|
|
|
@ -11,6 +11,7 @@ use gpui::{
|
||||||
use postage::prelude::Stream;
|
use postage::prelude::Stream;
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
|
mem,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
@ -71,7 +72,7 @@ pub enum ChannelListEvent {}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub enum ChannelEvent {
|
pub enum ChannelEvent {
|
||||||
MessagesAdded {
|
MessagesUpdated {
|
||||||
old_range: Range<usize>,
|
old_range: Range<usize>,
|
||||||
new_count: usize,
|
new_count: usize,
|
||||||
},
|
},
|
||||||
|
@ -87,36 +88,47 @@ impl ChannelList {
|
||||||
rpc: Arc<rpc::Client>,
|
rpc: Arc<rpc::Client>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let _task = cx.spawn(|this, mut cx| {
|
let _task = cx.spawn_weak(|this, mut cx| {
|
||||||
let rpc = rpc.clone();
|
let rpc = rpc.clone();
|
||||||
async move {
|
async move {
|
||||||
let mut user_id = rpc.user_id();
|
let mut status = rpc.status();
|
||||||
loop {
|
while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
|
||||||
let available_channels = if user_id.recv().await.unwrap().is_some() {
|
match status {
|
||||||
Some(
|
rpc::Status::Connected { .. } => {
|
||||||
rpc.request(proto::GetChannels {})
|
let response = rpc
|
||||||
|
.request(proto::GetChannels {})
|
||||||
.await
|
.await
|
||||||
.context("failed to fetch available channels")?
|
.context("failed to fetch available channels")?;
|
||||||
.channels
|
this.update(&mut cx, |this, cx| {
|
||||||
.into_iter()
|
this.available_channels =
|
||||||
.map(Into::into)
|
Some(response.channels.into_iter().map(Into::into).collect());
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
let mut to_remove = Vec::new();
|
||||||
if available_channels.is_none() {
|
for (channel_id, channel) in &this.channels {
|
||||||
if this.available_channels.is_none() {
|
if let Some(channel) = channel.upgrade(cx) {
|
||||||
return;
|
channel.update(cx, |channel, cx| channel.rejoin(cx))
|
||||||
}
|
} else {
|
||||||
this.channels.clear();
|
to_remove.push(*channel_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for channel_id in to_remove {
|
||||||
|
this.channels.remove(&channel_id);
|
||||||
|
}
|
||||||
|
cx.notify();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
this.available_channels = available_channels;
|
rpc::Status::Disconnected { .. } => {
|
||||||
cx.notify();
|
this.update(&mut cx, |this, cx| {
|
||||||
});
|
this.available_channels = None;
|
||||||
|
this.channels.clear();
|
||||||
|
cx.notify();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
.log_err()
|
.log_err()
|
||||||
});
|
});
|
||||||
|
@ -285,6 +297,43 @@ impl Channel {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
|
let user_store = self.user_store.clone();
|
||||||
|
let rpc = self.rpc.clone();
|
||||||
|
let channel_id = self.details.id;
|
||||||
|
cx.spawn(|channel, mut cx| {
|
||||||
|
async move {
|
||||||
|
let response = rpc.request(proto::JoinChannel { channel_id }).await?;
|
||||||
|
let messages = messages_from_proto(response.messages, &user_store).await?;
|
||||||
|
let loaded_all_messages = response.done;
|
||||||
|
|
||||||
|
channel.update(&mut cx, |channel, cx| {
|
||||||
|
if let Some((first_new_message, last_old_message)) =
|
||||||
|
messages.first().zip(channel.messages.last())
|
||||||
|
{
|
||||||
|
if first_new_message.id > last_old_message.id {
|
||||||
|
let old_messages = mem::take(&mut channel.messages);
|
||||||
|
cx.emit(ChannelEvent::MessagesUpdated {
|
||||||
|
old_range: 0..old_messages.summary().count,
|
||||||
|
new_count: 0,
|
||||||
|
});
|
||||||
|
channel.loaded_all_messages = loaded_all_messages;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channel.insert_messages(messages, cx);
|
||||||
|
if loaded_all_messages {
|
||||||
|
channel.loaded_all_messages = loaded_all_messages;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.log_err()
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
}
|
||||||
|
|
||||||
pub fn message_count(&self) -> usize {
|
pub fn message_count(&self) -> usize {
|
||||||
self.messages.summary().count
|
self.messages.summary().count
|
||||||
}
|
}
|
||||||
|
@ -350,7 +399,7 @@ impl Channel {
|
||||||
drop(old_cursor);
|
drop(old_cursor);
|
||||||
self.messages = new_messages;
|
self.messages = new_messages;
|
||||||
|
|
||||||
cx.emit(ChannelEvent::MessagesAdded {
|
cx.emit(ChannelEvent::MessagesUpdated {
|
||||||
old_range: start_ix..end_ix,
|
old_range: start_ix..end_ix,
|
||||||
new_count,
|
new_count,
|
||||||
});
|
});
|
||||||
|
@ -446,22 +495,21 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test::FakeServer;
|
||||||
use gpui::TestAppContext;
|
use gpui::TestAppContext;
|
||||||
use postage::mpsc::Receiver;
|
|
||||||
use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
|
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_channel_messages(mut cx: TestAppContext) {
|
async fn test_channel_messages(mut cx: TestAppContext) {
|
||||||
let user_id = 5;
|
let user_id = 5;
|
||||||
let client = Client::new();
|
let mut client = Client::new();
|
||||||
let mut server = FakeServer::for_client(user_id, &client, &cx).await;
|
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||||
let user_store = Arc::new(UserStore::new(client.clone()));
|
let user_store = Arc::new(UserStore::new(client.clone()));
|
||||||
|
|
||||||
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
|
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
|
||||||
channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
|
channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
|
||||||
|
|
||||||
// Get the available channels.
|
// Get the available channels.
|
||||||
let get_channels = server.receive::<proto::GetChannels>().await;
|
let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
|
||||||
server
|
server
|
||||||
.respond(
|
.respond(
|
||||||
get_channels.receipt(),
|
get_channels.receipt(),
|
||||||
|
@ -492,7 +540,7 @@ mod tests {
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
|
channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
|
||||||
let join_channel = server.receive::<proto::JoinChannel>().await;
|
let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
|
||||||
server
|
server
|
||||||
.respond(
|
.respond(
|
||||||
join_channel.receipt(),
|
join_channel.receipt(),
|
||||||
|
@ -517,7 +565,7 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Client requests all users for the received messages
|
// Client requests all users for the received messages
|
||||||
let mut get_users = server.receive::<proto::GetUsers>().await;
|
let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
|
||||||
get_users.payload.user_ids.sort();
|
get_users.payload.user_ids.sort();
|
||||||
assert_eq!(get_users.payload.user_ids, vec![5, 6]);
|
assert_eq!(get_users.payload.user_ids, vec![5, 6]);
|
||||||
server
|
server
|
||||||
|
@ -542,7 +590,7 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
channel.next_event(&cx).await,
|
channel.next_event(&cx).await,
|
||||||
ChannelEvent::MessagesAdded {
|
ChannelEvent::MessagesUpdated {
|
||||||
old_range: 0..0,
|
old_range: 0..0,
|
||||||
new_count: 2,
|
new_count: 2,
|
||||||
}
|
}
|
||||||
|
@ -574,7 +622,7 @@ mod tests {
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Client requests user for message since they haven't seen them yet
|
// Client requests user for message since they haven't seen them yet
|
||||||
let get_users = server.receive::<proto::GetUsers>().await;
|
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
|
||||||
assert_eq!(get_users.payload.user_ids, vec![7]);
|
assert_eq!(get_users.payload.user_ids, vec![7]);
|
||||||
server
|
server
|
||||||
.respond(
|
.respond(
|
||||||
|
@ -591,7 +639,7 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
channel.next_event(&cx).await,
|
channel.next_event(&cx).await,
|
||||||
ChannelEvent::MessagesAdded {
|
ChannelEvent::MessagesUpdated {
|
||||||
old_range: 2..2,
|
old_range: 2..2,
|
||||||
new_count: 1,
|
new_count: 1,
|
||||||
}
|
}
|
||||||
|
@ -610,7 +658,7 @@ mod tests {
|
||||||
channel.update(&mut cx, |channel, cx| {
|
channel.update(&mut cx, |channel, cx| {
|
||||||
assert!(channel.load_more_messages(cx));
|
assert!(channel.load_more_messages(cx));
|
||||||
});
|
});
|
||||||
let get_messages = server.receive::<proto::GetChannelMessages>().await;
|
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
|
||||||
assert_eq!(get_messages.payload.channel_id, 5);
|
assert_eq!(get_messages.payload.channel_id, 5);
|
||||||
assert_eq!(get_messages.payload.before_message_id, 10);
|
assert_eq!(get_messages.payload.before_message_id, 10);
|
||||||
server
|
server
|
||||||
|
@ -638,7 +686,7 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
channel.next_event(&cx).await,
|
channel.next_event(&cx).await,
|
||||||
ChannelEvent::MessagesAdded {
|
ChannelEvent::MessagesUpdated {
|
||||||
old_range: 0..0,
|
old_range: 0..0,
|
||||||
new_count: 2,
|
new_count: 2,
|
||||||
}
|
}
|
||||||
|
@ -656,53 +704,4 @@ mod tests {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct FakeServer {
|
|
||||||
peer: Arc<Peer>,
|
|
||||||
incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
|
|
||||||
connection_id: ConnectionId,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeServer {
|
|
||||||
async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
|
|
||||||
let (client_conn, server_conn) = Channel::bidirectional();
|
|
||||||
let peer = Peer::new();
|
|
||||||
let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
|
|
||||||
cx.background().spawn(io).detach();
|
|
||||||
|
|
||||||
client
|
|
||||||
.add_connection(user_id, client_conn, &cx.to_async())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
peer,
|
|
||||||
incoming,
|
|
||||||
connection_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
|
||||||
self.peer.send(self.connection_id, message).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
|
|
||||||
*self
|
|
||||||
.incoming
|
|
||||||
.recv()
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.into_any()
|
|
||||||
.downcast::<TypedEnvelope<M>>()
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn respond<T: proto::RequestMessage>(
|
|
||||||
&self,
|
|
||||||
receipt: Receipt<T>,
|
|
||||||
response: T::Response,
|
|
||||||
) {
|
|
||||||
self.peer.respond(receipt, response).await.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||||
use crate::{
|
use crate::{
|
||||||
channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
|
channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
|
||||||
editor::Editor,
|
editor::Editor,
|
||||||
rpc::Client,
|
rpc::{self, Client},
|
||||||
theme,
|
theme,
|
||||||
util::{ResultExt, TryFutureExt},
|
util::{ResultExt, TryFutureExt},
|
||||||
Settings,
|
Settings,
|
||||||
|
@ -14,10 +14,10 @@ use gpui::{
|
||||||
keymap::Binding,
|
keymap::Binding,
|
||||||
platform::CursorStyle,
|
platform::CursorStyle,
|
||||||
views::{ItemType, Select, SelectStyle},
|
views::{ItemType, Select, SelectStyle},
|
||||||
AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
|
AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
|
||||||
ViewContext, ViewHandle,
|
ViewContext, ViewHandle,
|
||||||
};
|
};
|
||||||
use postage::watch;
|
use postage::{prelude::Stream, watch};
|
||||||
use time::{OffsetDateTime, UtcOffset};
|
use time::{OffsetDateTime, UtcOffset};
|
||||||
|
|
||||||
const MESSAGE_LOADING_THRESHOLD: usize = 50;
|
const MESSAGE_LOADING_THRESHOLD: usize = 50;
|
||||||
|
@ -31,6 +31,7 @@ pub struct ChatPanel {
|
||||||
channel_select: ViewHandle<Select>,
|
channel_select: ViewHandle<Select>,
|
||||||
settings: watch::Receiver<Settings>,
|
settings: watch::Receiver<Settings>,
|
||||||
local_timezone: UtcOffset,
|
local_timezone: UtcOffset,
|
||||||
|
_observe_status: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Event {}
|
pub enum Event {}
|
||||||
|
@ -98,6 +99,14 @@ impl ChatPanel {
|
||||||
cx.dispatch_action(LoadMoreMessages);
|
cx.dispatch_action(LoadMoreMessages);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
let _observe_status = cx.spawn(|this, mut cx| {
|
||||||
|
let mut status = rpc.status();
|
||||||
|
async move {
|
||||||
|
while let Some(_) = status.recv().await {
|
||||||
|
this.update(&mut cx, |_, cx| cx.notify());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
rpc,
|
rpc,
|
||||||
|
@ -108,6 +117,7 @@ impl ChatPanel {
|
||||||
channel_select,
|
channel_select,
|
||||||
settings,
|
settings,
|
||||||
local_timezone: cx.platform().local_timezone(),
|
local_timezone: cx.platform().local_timezone(),
|
||||||
|
_observe_status,
|
||||||
};
|
};
|
||||||
|
|
||||||
this.init_active_channel(cx);
|
this.init_active_channel(cx);
|
||||||
|
@ -153,6 +163,7 @@ impl ChatPanel {
|
||||||
if let Some(active_channel) = active_channel {
|
if let Some(active_channel) = active_channel {
|
||||||
self.set_active_channel(active_channel, cx);
|
self.set_active_channel(active_channel, cx);
|
||||||
} else {
|
} else {
|
||||||
|
self.message_list.reset(0);
|
||||||
self.active_channel = None;
|
self.active_channel = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +194,7 @@ impl ChatPanel {
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) {
|
) {
|
||||||
match event {
|
match event {
|
||||||
ChannelEvent::MessagesAdded {
|
ChannelEvent::MessagesUpdated {
|
||||||
old_range,
|
old_range,
|
||||||
new_count,
|
new_count,
|
||||||
} => {
|
} => {
|
||||||
|
@ -357,10 +368,6 @@ impl ChatPanel {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_signed_in(&self) -> bool {
|
|
||||||
self.rpc.user_id().borrow().is_some()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Entity for ChatPanel {
|
impl Entity for ChatPanel {
|
||||||
|
@ -374,10 +381,9 @@ impl View for ChatPanel {
|
||||||
|
|
||||||
fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
|
fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
|
||||||
let theme = &self.settings.borrow().theme;
|
let theme = &self.settings.borrow().theme;
|
||||||
let element = if self.is_signed_in() {
|
let element = match *self.rpc.status().borrow() {
|
||||||
self.render_channel()
|
rpc::Status::Connected { .. } => self.render_channel(),
|
||||||
} else {
|
_ => self.render_sign_in_prompt(cx),
|
||||||
self.render_sign_in_prompt(cx)
|
|
||||||
};
|
};
|
||||||
ConstrainedBox::new(
|
ConstrainedBox::new(
|
||||||
Container::new(element)
|
Container::new(element)
|
||||||
|
@ -389,7 +395,7 @@ impl View for ChatPanel {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
|
fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
if self.is_signed_in() {
|
if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
|
||||||
cx.focus(&self.input_editor);
|
cx.focus(&self.input_editor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2695,14 +2695,7 @@ impl<'a> Into<proto::operation::Edit> for &'a EditOperation {
|
||||||
impl<'a> Into<proto::Anchor> for &'a Anchor {
|
impl<'a> Into<proto::Anchor> for &'a Anchor {
|
||||||
fn into(self) -> proto::Anchor {
|
fn into(self) -> proto::Anchor {
|
||||||
proto::Anchor {
|
proto::Anchor {
|
||||||
version: self
|
version: (&self.version).into(),
|
||||||
.version
|
|
||||||
.iter()
|
|
||||||
.map(|entry| proto::VectorClockEntry {
|
|
||||||
replica_id: entry.replica_id as u32,
|
|
||||||
timestamp: entry.value,
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
offset: self.offset as u64,
|
offset: self.offset as u64,
|
||||||
bias: match self.bias {
|
bias: match self.bias {
|
||||||
Bias::Left => proto::anchor::Bias::Left as i32,
|
Bias::Left => proto::anchor::Bias::Left as i32,
|
||||||
|
|
424
zed/src/rpc.rs
424
zed/src/rpc.rs
|
@ -1,24 +1,24 @@
|
||||||
use crate::util::ResultExt;
|
use crate::util::ResultExt;
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_tungstenite::tungstenite::http::Request;
|
use async_tungstenite::tungstenite::http::Request;
|
||||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
|
||||||
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
|
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
use postage::prelude::Stream;
|
use postage::{prelude::Stream, watch};
|
||||||
use postage::sink::Sink;
|
use rand::prelude::*;
|
||||||
use postage::watch;
|
use std::{
|
||||||
use std::any::TypeId;
|
any::TypeId,
|
||||||
use std::collections::HashMap;
|
collections::HashMap,
|
||||||
use std::sync::Weak;
|
convert::TryFrom,
|
||||||
use std::time::{Duration, Instant};
|
future::Future,
|
||||||
use std::{convert::TryFrom, future::Future, sync::Arc};
|
sync::{Arc, Weak},
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
use surf::Url;
|
use surf::Url;
|
||||||
use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
|
|
||||||
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
|
||||||
use zrpc::{
|
use zrpc::{
|
||||||
proto::{EnvelopedMessage, RequestMessage},
|
proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
|
||||||
Peer, Receipt,
|
Conn, Peer, Receipt,
|
||||||
};
|
};
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
|
@ -29,25 +29,55 @@ lazy_static! {
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
peer: Arc<Peer>,
|
peer: Arc<Peer>,
|
||||||
state: RwLock<ClientState>,
|
state: RwLock<ClientState>,
|
||||||
|
auth_callback: Option<
|
||||||
|
Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
|
||||||
|
>,
|
||||||
|
connect_callback: Option<
|
||||||
|
Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
|
||||||
|
>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
|
pub enum Status {
|
||||||
|
Disconnected,
|
||||||
|
Authenticating,
|
||||||
|
Connecting {
|
||||||
|
user_id: u64,
|
||||||
|
},
|
||||||
|
ConnectionError,
|
||||||
|
Connected {
|
||||||
|
connection_id: ConnectionId,
|
||||||
|
user_id: u64,
|
||||||
|
},
|
||||||
|
ConnectionLost,
|
||||||
|
Reauthenticating,
|
||||||
|
Reconnecting {
|
||||||
|
user_id: u64,
|
||||||
|
},
|
||||||
|
ReconnectionError {
|
||||||
|
next_reconnection: Instant,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ClientState {
|
struct ClientState {
|
||||||
connection_id: Option<ConnectionId>,
|
status: (watch::Sender<Status>, watch::Receiver<Status>),
|
||||||
user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
|
|
||||||
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
|
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
|
||||||
model_handlers: HashMap<
|
model_handlers: HashMap<
|
||||||
(TypeId, u64),
|
(TypeId, u64),
|
||||||
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
|
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
|
||||||
>,
|
>,
|
||||||
|
_maintain_connection: Option<Task<()>>,
|
||||||
|
heartbeat_interval: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ClientState {
|
impl Default for ClientState {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
connection_id: Default::default(),
|
status: watch::channel_with(Status::Disconnected),
|
||||||
user_id: watch::channel(),
|
|
||||||
entity_id_extractors: Default::default(),
|
entity_id_extractors: Default::default(),
|
||||||
model_handlers: Default::default(),
|
model_handlers: Default::default(),
|
||||||
|
_maintain_connection: None,
|
||||||
|
heartbeat_interval: Duration::from_secs(5),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,11 +107,71 @@ impl Client {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
peer: Peer::new(),
|
peer: Peer::new(),
|
||||||
state: Default::default(),
|
state: Default::default(),
|
||||||
|
auth_callback: None,
|
||||||
|
connect_callback: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
self.state.read().user_id.1.clone()
|
pub fn set_login_and_connect_callbacks<Login, Connect>(
|
||||||
|
&mut self,
|
||||||
|
login: Login,
|
||||||
|
connect: Connect,
|
||||||
|
) where
|
||||||
|
Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
|
||||||
|
Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
|
||||||
|
{
|
||||||
|
self.auth_callback = Some(Box::new(login));
|
||||||
|
self.connect_callback = Some(Box::new(connect));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn status(&self) -> watch::Receiver<Status> {
|
||||||
|
self.state.read().status.1.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
|
||||||
|
let mut state = self.state.write();
|
||||||
|
*state.status.0.borrow_mut() = status;
|
||||||
|
|
||||||
|
match status {
|
||||||
|
Status::Connected { .. } => {
|
||||||
|
let heartbeat_interval = state.heartbeat_interval;
|
||||||
|
let this = self.clone();
|
||||||
|
let foreground = cx.foreground();
|
||||||
|
state._maintain_connection = Some(cx.foreground().spawn(async move {
|
||||||
|
loop {
|
||||||
|
foreground.timer(heartbeat_interval).await;
|
||||||
|
this.request(proto::Ping {}).await.unwrap();
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Status::ConnectionLost => {
|
||||||
|
let this = self.clone();
|
||||||
|
let foreground = cx.foreground();
|
||||||
|
let heartbeat_interval = state.heartbeat_interval;
|
||||||
|
state._maintain_connection = Some(cx.spawn(|cx| async move {
|
||||||
|
let mut rng = StdRng::from_entropy();
|
||||||
|
let mut delay = Duration::from_millis(100);
|
||||||
|
while let Err(error) = this.authenticate_and_connect(&cx).await {
|
||||||
|
log::error!("failed to connect {}", error);
|
||||||
|
this.set_status(
|
||||||
|
Status::ReconnectionError {
|
||||||
|
next_reconnection: Instant::now() + delay,
|
||||||
|
},
|
||||||
|
&cx,
|
||||||
|
);
|
||||||
|
foreground.timer(delay).await;
|
||||||
|
delay = delay
|
||||||
|
.mul_f32(rng.gen_range(1.0..=2.0))
|
||||||
|
.min(heartbeat_interval);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Status::Disconnected => {
|
||||||
|
state._maintain_connection.take();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn subscribe_from_model<T, M, F>(
|
pub fn subscribe_from_model<T, M, F>(
|
||||||
|
@ -141,56 +231,57 @@ impl Client {
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
if self.state.read().connection_id.is_some() {
|
let was_disconnected = match *self.status().borrow() {
|
||||||
return Ok(());
|
Status::Disconnected => true,
|
||||||
}
|
Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
|
||||||
|
false
|
||||||
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
|
}
|
||||||
let user_id = user_id.parse::<u64>()?;
|
Status::Connected { .. }
|
||||||
let request =
|
| Status::Connecting { .. }
|
||||||
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
|
| Status::Reconnecting { .. }
|
||||||
|
| Status::Authenticating
|
||||||
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
|
| Status::Reauthenticating => return Ok(()),
|
||||||
let stream = smol::net::TcpStream::connect(host).await?;
|
|
||||||
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
|
|
||||||
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
|
|
||||||
.await
|
|
||||||
.context("websocket handshake")?;
|
|
||||||
self.add_connection(user_id, stream, cx).await?;
|
|
||||||
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
|
||||||
let stream = smol::net::TcpStream::connect(host).await?;
|
|
||||||
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
|
||||||
let (stream, _) = async_tungstenite::client_async(request, stream)
|
|
||||||
.await
|
|
||||||
.context("websocket handshake")?;
|
|
||||||
self.add_connection(user_id, stream, cx).await?;
|
|
||||||
} else {
|
|
||||||
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
if was_disconnected {
|
||||||
Ok(())
|
self.set_status(Status::Authenticating, cx);
|
||||||
|
} else {
|
||||||
|
self.set_status(Status::Reauthenticating, cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
let (user_id, access_token) = match self.authenticate(&cx).await {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(err) => {
|
||||||
|
self.set_status(Status::ConnectionError, cx);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if was_disconnected {
|
||||||
|
self.set_status(Status::Connecting { user_id }, cx);
|
||||||
|
} else {
|
||||||
|
self.set_status(Status::Reconnecting { user_id }, cx);
|
||||||
|
}
|
||||||
|
match self.connect(user_id, &access_token, cx).await {
|
||||||
|
Ok(conn) => {
|
||||||
|
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
|
||||||
|
self.set_connection(user_id, conn, cx).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.set_status(Status::ConnectionError, cx);
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connection<Conn>(
|
async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
|
||||||
self: &Arc<Self>,
|
|
||||||
user_id: u64,
|
|
||||||
conn: Conn,
|
|
||||||
cx: &AsyncAppContext,
|
|
||||||
) -> anyhow::Result<()>
|
|
||||||
where
|
|
||||||
Conn: 'static
|
|
||||||
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
||||||
+ Unpin
|
|
||||||
+ Send,
|
|
||||||
{
|
|
||||||
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
|
||||||
{
|
cx.foreground()
|
||||||
let mut cx = cx.clone();
|
.spawn({
|
||||||
let this = self.clone();
|
let mut cx = cx.clone();
|
||||||
cx.foreground()
|
let this = self.clone();
|
||||||
.spawn(async move {
|
async move {
|
||||||
while let Some(message) = incoming.recv().await {
|
while let Some(message) = incoming.recv().await {
|
||||||
let mut state = this.state.write();
|
let mut state = this.state.write();
|
||||||
if let Some(extract_entity_id) =
|
if let Some(extract_entity_id) =
|
||||||
|
@ -215,27 +306,90 @@ impl Client {
|
||||||
log::info!("unhandled message {}", message.payload_type_name());
|
log::info!("unhandled message {}", message.payload_type_name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
}
|
|
||||||
cx.background()
|
|
||||||
.spawn(async move {
|
|
||||||
if let Err(error) = handle_io.await {
|
|
||||||
log::error!("connection error: {:?}", error);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
let mut state = self.state.write();
|
|
||||||
state.connection_id = Some(connection_id);
|
self.set_status(
|
||||||
state.user_id.0.send(Some(user_id)).await?;
|
Status::Connected {
|
||||||
Ok(())
|
connection_id,
|
||||||
|
user_id,
|
||||||
|
},
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle_io = cx.background().spawn(handle_io);
|
||||||
|
let this = self.clone();
|
||||||
|
let cx = cx.clone();
|
||||||
|
cx.foreground()
|
||||||
|
.spawn(async move {
|
||||||
|
match handle_io.await {
|
||||||
|
Ok(()) => this.set_status(Status::Disconnected, &cx),
|
||||||
|
Err(err) => {
|
||||||
|
log::error!("connection error: {:?}", err);
|
||||||
|
this.set_status(Status::ConnectionLost, &cx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn login(
|
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
|
||||||
platform: Arc<dyn gpui::Platform>,
|
if let Some(callback) = self.auth_callback.as_ref() {
|
||||||
executor: &Arc<gpui::executor::Background>,
|
callback(cx)
|
||||||
) -> Task<Result<(String, String)>> {
|
} else {
|
||||||
let executor = executor.clone();
|
self.authenticate_with_browser(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connect(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
user_id: u64,
|
||||||
|
access_token: &str,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> Task<Result<Conn>> {
|
||||||
|
if let Some(callback) = self.connect_callback.as_ref() {
|
||||||
|
callback(user_id, access_token, cx)
|
||||||
|
} else {
|
||||||
|
self.connect_with_websocket(user_id, access_token, cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connect_with_websocket(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
user_id: u64,
|
||||||
|
access_token: &str,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> Task<Result<Conn>> {
|
||||||
|
let request =
|
||||||
|
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
|
||||||
|
cx.background().spawn(async move {
|
||||||
|
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
|
||||||
|
let stream = smol::net::TcpStream::connect(host).await?;
|
||||||
|
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
|
||||||
|
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
|
||||||
|
.await
|
||||||
|
.context("websocket handshake")?;
|
||||||
|
Ok(Conn::new(stream))
|
||||||
|
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
|
||||||
|
let stream = smol::net::TcpStream::connect(host).await?;
|
||||||
|
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
|
||||||
|
let (stream, _) = async_tungstenite::client_async(request, stream)
|
||||||
|
.await
|
||||||
|
.context("websocket handshake")?;
|
||||||
|
Ok(Conn::new(stream))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn authenticate_with_browser(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
cx: &AsyncAppContext,
|
||||||
|
) -> Task<Result<(u64, String)>> {
|
||||||
|
let platform = cx.platform();
|
||||||
|
let executor = cx.background();
|
||||||
executor.clone().spawn(async move {
|
executor.clone().spawn(async move {
|
||||||
if let Some((user_id, access_token)) = platform
|
if let Some((user_id, access_token)) = platform
|
||||||
.read_credentials(&ZED_SERVER_URL)
|
.read_credentials(&ZED_SERVER_URL)
|
||||||
|
@ -243,7 +397,7 @@ impl Client {
|
||||||
.flatten()
|
.flatten()
|
||||||
{
|
{
|
||||||
log::info!("already signed in. user_id: {}", user_id);
|
log::info!("already signed in. user_id: {}", user_id);
|
||||||
return Ok((user_id, String::from_utf8(access_token).unwrap()));
|
return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a pair of asymmetric encryption keys. The public key will be used by the
|
// Generate a pair of asymmetric encryption keys. The public key will be used by the
|
||||||
|
@ -309,21 +463,23 @@ impl Client {
|
||||||
platform
|
platform
|
||||||
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
|
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
|
||||||
.log_err();
|
.log_err();
|
||||||
Ok((user_id.to_string(), access_token))
|
Ok((user_id.parse()?, access_token))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn disconnect(&self) -> Result<()> {
|
pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
|
||||||
let conn_id = self.connection_id()?;
|
let conn_id = self.connection_id()?;
|
||||||
self.peer.disconnect(conn_id).await;
|
self.peer.disconnect(conn_id).await;
|
||||||
|
self.set_status(Status::Disconnected, cx);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn connection_id(&self) -> Result<ConnectionId> {
|
fn connection_id(&self) -> Result<ConnectionId> {
|
||||||
self.state
|
if let Status::Connected { connection_id, .. } = *self.status().borrow() {
|
||||||
.read()
|
Ok(connection_id)
|
||||||
.connection_id
|
} else {
|
||||||
.ok_or_else(|| anyhow!("not connected"))
|
Err(anyhow!("not connected"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
|
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
|
||||||
|
@ -343,35 +499,6 @@ impl Client {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
|
|
||||||
type Output: 'a + Future<Output = anyhow::Result<()>>;
|
|
||||||
|
|
||||||
fn handle(
|
|
||||||
&self,
|
|
||||||
message: TypedEnvelope<M>,
|
|
||||||
rpc: &'a Client,
|
|
||||||
cx: &'a mut gpui::AsyncAppContext,
|
|
||||||
) -> Self::Output;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, M, F, Fut> MessageHandler<'a, M> for F
|
|
||||||
where
|
|
||||||
M: proto::EnvelopedMessage,
|
|
||||||
F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
|
|
||||||
Fut: 'a + Future<Output = anyhow::Result<()>>,
|
|
||||||
{
|
|
||||||
type Output = Fut;
|
|
||||||
|
|
||||||
fn handle(
|
|
||||||
&self,
|
|
||||||
message: TypedEnvelope<M>,
|
|
||||||
rpc: &'a Client,
|
|
||||||
cx: &'a mut gpui::AsyncAppContext,
|
|
||||||
) -> Self::Output {
|
|
||||||
(self)(message, rpc, cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
|
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
|
||||||
|
|
||||||
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
|
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
|
||||||
|
@ -396,13 +523,62 @@ const LOGIN_RESPONSE: &'static str = "
|
||||||
</html>
|
</html>
|
||||||
";
|
";
|
||||||
|
|
||||||
#[test]
|
#[cfg(test)]
|
||||||
fn test_encode_and_decode_worktree_url() {
|
mod tests {
|
||||||
let url = encode_worktree_url(5, "deadbeef");
|
use super::*;
|
||||||
assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
|
use crate::test::FakeServer;
|
||||||
assert_eq!(
|
use gpui::TestAppContext;
|
||||||
decode_worktree_url(&format!("\n {}\t", url)),
|
|
||||||
Some((5, "deadbeef".to_string()))
|
#[gpui::test(iterations = 10)]
|
||||||
);
|
async fn test_heartbeat(cx: TestAppContext) {
|
||||||
assert_eq!(decode_worktree_url("not://the-right-format"), None);
|
cx.foreground().forbid_parking();
|
||||||
|
|
||||||
|
let user_id = 5;
|
||||||
|
let mut client = Client::new();
|
||||||
|
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||||
|
|
||||||
|
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||||
|
let ping = server.receive::<proto::Ping>().await.unwrap();
|
||||||
|
server.respond(ping.receipt(), proto::Ack {}).await;
|
||||||
|
|
||||||
|
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||||
|
let ping = server.receive::<proto::Ping>().await.unwrap();
|
||||||
|
server.respond(ping.receipt(), proto::Ack {}).await;
|
||||||
|
|
||||||
|
client.disconnect(&cx.to_async()).await.unwrap();
|
||||||
|
assert!(server.receive::<proto::Ping>().await.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test(iterations = 10)]
|
||||||
|
async fn test_reconnection(cx: TestAppContext) {
|
||||||
|
cx.foreground().forbid_parking();
|
||||||
|
|
||||||
|
let user_id = 5;
|
||||||
|
let mut client = Client::new();
|
||||||
|
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
|
||||||
|
let mut status = client.status();
|
||||||
|
assert!(matches!(
|
||||||
|
status.recv().await,
|
||||||
|
Some(Status::Connected { .. })
|
||||||
|
));
|
||||||
|
|
||||||
|
server.forbid_connections();
|
||||||
|
server.disconnect().await;
|
||||||
|
while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
|
||||||
|
|
||||||
|
server.allow_connections();
|
||||||
|
cx.foreground().advance_clock(Duration::from_secs(10));
|
||||||
|
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_and_decode_worktree_url() {
|
||||||
|
let url = encode_worktree_url(5, "deadbeef");
|
||||||
|
assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
|
||||||
|
assert_eq!(
|
||||||
|
decode_worktree_url(&format!("\n {}\t", url)),
|
||||||
|
Some((5, "deadbeef".to_string()))
|
||||||
|
);
|
||||||
|
assert_eq!(decode_worktree_url("not://the-right-format"), None);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
129
zed/src/test.rs
129
zed/src/test.rs
|
@ -3,24 +3,27 @@ use crate::{
|
||||||
channel::ChannelList,
|
channel::ChannelList,
|
||||||
fs::RealFs,
|
fs::RealFs,
|
||||||
language::LanguageRegistry,
|
language::LanguageRegistry,
|
||||||
rpc,
|
rpc::{self, Client},
|
||||||
settings::{self, ThemeRegistry},
|
settings::{self, ThemeRegistry},
|
||||||
time::ReplicaId,
|
time::ReplicaId,
|
||||||
user::UserStore,
|
user::UserStore,
|
||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
use gpui::{Entity, ModelHandle, MutableAppContext};
|
use anyhow::{anyhow, Result};
|
||||||
|
use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use postage::{mpsc, prelude::Stream as _};
|
||||||
use smol::channel;
|
use smol::channel;
|
||||||
use std::{
|
use std::{
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
sync::Arc,
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering::SeqCst},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use tempdir::TempDir;
|
use tempdir::TempDir;
|
||||||
|
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
|
||||||
#[cfg(feature = "test-support")]
|
|
||||||
pub use zrpc::test::Channel;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[ctor::ctor]
|
#[ctor::ctor]
|
||||||
|
@ -195,3 +198,117 @@ impl<T: Entity> Observer<T> {
|
||||||
(observer, notify_rx)
|
(observer, notify_rx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct FakeServer {
|
||||||
|
peer: Arc<Peer>,
|
||||||
|
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
|
||||||
|
connection_id: Mutex<Option<ConnectionId>>,
|
||||||
|
forbid_connections: AtomicBool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeServer {
|
||||||
|
pub async fn for_client(
|
||||||
|
client_user_id: u64,
|
||||||
|
client: &mut Arc<Client>,
|
||||||
|
cx: &TestAppContext,
|
||||||
|
) -> Arc<Self> {
|
||||||
|
let result = Arc::new(Self {
|
||||||
|
peer: Peer::new(),
|
||||||
|
incoming: Default::default(),
|
||||||
|
connection_id: Default::default(),
|
||||||
|
forbid_connections: Default::default(),
|
||||||
|
});
|
||||||
|
|
||||||
|
Arc::get_mut(client)
|
||||||
|
.unwrap()
|
||||||
|
.set_login_and_connect_callbacks(
|
||||||
|
move |cx| {
|
||||||
|
cx.spawn(|_| async move {
|
||||||
|
let access_token = "the-token".to_string();
|
||||||
|
Ok((client_user_id, access_token))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
{
|
||||||
|
let server = result.clone();
|
||||||
|
move |user_id, access_token, cx| {
|
||||||
|
assert_eq!(user_id, client_user_id);
|
||||||
|
assert_eq!(access_token, "the-token");
|
||||||
|
cx.spawn({
|
||||||
|
let server = server.clone();
|
||||||
|
move |cx| async move { server.connect(&cx).await }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
client
|
||||||
|
.authenticate_and_connect(&cx.to_async())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn disconnect(&self) {
|
||||||
|
self.peer.disconnect(self.connection_id()).await;
|
||||||
|
self.connection_id.lock().take();
|
||||||
|
self.incoming.lock().take();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
|
||||||
|
if self.forbid_connections.load(SeqCst) {
|
||||||
|
Err(anyhow!("server is forbidding connections"))
|
||||||
|
} else {
|
||||||
|
let (client_conn, server_conn, _) = Conn::in_memory();
|
||||||
|
let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
|
||||||
|
cx.background().spawn(io).detach();
|
||||||
|
*self.incoming.lock() = Some(incoming);
|
||||||
|
*self.connection_id.lock() = Some(connection_id);
|
||||||
|
Ok(client_conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forbid_connections(&self) {
|
||||||
|
self.forbid_connections.store(true, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allow_connections(&self) {
|
||||||
|
self.forbid_connections.store(false, SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
|
||||||
|
self.peer.send(self.connection_id(), message).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
|
||||||
|
let message = self
|
||||||
|
.incoming
|
||||||
|
.lock()
|
||||||
|
.as_mut()
|
||||||
|
.expect("not connected")
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow!("other half hung up"))?;
|
||||||
|
let type_name = message.payload_type_name();
|
||||||
|
Ok(*message
|
||||||
|
.into_any()
|
||||||
|
.downcast::<TypedEnvelope<M>>()
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
panic!(
|
||||||
|
"fake server received unexpected message type: {:?}",
|
||||||
|
type_name
|
||||||
|
);
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn respond<T: proto::RequestMessage>(
|
||||||
|
&self,
|
||||||
|
receipt: Receipt<T>,
|
||||||
|
response: T::Response,
|
||||||
|
) {
|
||||||
|
self.peer.respond(receipt, response).await.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn connection_id(&self) -> ConnectionId {
|
||||||
|
self.connection_id.lock().expect("not connected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -234,6 +234,7 @@ impl Worktree {
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
|
.map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
|
||||||
.collect(),
|
.collect(),
|
||||||
|
queued_operations: Default::default(),
|
||||||
languages,
|
languages,
|
||||||
_subscriptions,
|
_subscriptions,
|
||||||
})
|
})
|
||||||
|
@ -656,6 +657,7 @@ pub struct LocalWorktree {
|
||||||
shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
|
shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
|
||||||
peers: HashMap<PeerId, ReplicaId>,
|
peers: HashMap<PeerId, ReplicaId>,
|
||||||
languages: Arc<LanguageRegistry>,
|
languages: Arc<LanguageRegistry>,
|
||||||
|
queued_operations: Vec<(u64, Operation)>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -711,6 +713,7 @@ impl LocalWorktree {
|
||||||
poll_task: None,
|
poll_task: None,
|
||||||
open_buffers: Default::default(),
|
open_buffers: Default::default(),
|
||||||
shared_buffers: Default::default(),
|
shared_buffers: Default::default(),
|
||||||
|
queued_operations: Default::default(),
|
||||||
peers: Default::default(),
|
peers: Default::default(),
|
||||||
languages,
|
languages,
|
||||||
fs,
|
fs,
|
||||||
|
@ -1091,6 +1094,7 @@ pub struct RemoteWorktree {
|
||||||
open_buffers: HashMap<usize, RemoteBuffer>,
|
open_buffers: HashMap<usize, RemoteBuffer>,
|
||||||
peers: HashMap<PeerId, ReplicaId>,
|
peers: HashMap<PeerId, ReplicaId>,
|
||||||
languages: Arc<LanguageRegistry>,
|
languages: Arc<LanguageRegistry>,
|
||||||
|
queued_operations: Vec<(u64, Operation)>,
|
||||||
_subscriptions: Vec<rpc::Subscription>,
|
_subscriptions: Vec<rpc::Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1550,16 +1554,23 @@ impl File {
|
||||||
.map(|share| (share.rpc.clone(), share.remote_id)),
|
.map(|share| (share.rpc.clone(), share.remote_id)),
|
||||||
Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
|
Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
|
||||||
} {
|
} {
|
||||||
cx.spawn(|_, _| async move {
|
cx.spawn(|worktree, mut cx| async move {
|
||||||
if let Err(error) = rpc
|
if let Err(error) = rpc
|
||||||
.send(proto::UpdateBuffer {
|
.request(proto::UpdateBuffer {
|
||||||
worktree_id: remote_id,
|
worktree_id: remote_id,
|
||||||
buffer_id,
|
buffer_id,
|
||||||
operations: Some(operation).iter().map(Into::into).collect(),
|
operations: vec![(&operation).into()],
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
log::error!("error sending buffer operation: {}", error);
|
worktree.update(&mut cx, |worktree, _| {
|
||||||
|
log::error!("error sending buffer operation: {}", error);
|
||||||
|
match worktree {
|
||||||
|
Worktree::Local(t) => &mut t.queued_operations,
|
||||||
|
Worktree::Remote(t) => &mut t.queued_operations,
|
||||||
|
}
|
||||||
|
.push((buffer_id, operation));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
@ -1582,7 +1593,7 @@ impl File {
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
log::error!("error closing remote buffer: {}", error);
|
log::error!("error closing remote buffer: {}", error);
|
||||||
};
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,9 +6,9 @@ message Envelope {
|
||||||
optional uint32 responding_to = 2;
|
optional uint32 responding_to = 2;
|
||||||
optional uint32 original_sender_id = 3;
|
optional uint32 original_sender_id = 3;
|
||||||
oneof payload {
|
oneof payload {
|
||||||
Error error = 4;
|
Ack ack = 4;
|
||||||
Ping ping = 5;
|
Error error = 5;
|
||||||
Pong pong = 6;
|
Ping ping = 6;
|
||||||
ShareWorktree share_worktree = 7;
|
ShareWorktree share_worktree = 7;
|
||||||
ShareWorktreeResponse share_worktree_response = 8;
|
ShareWorktreeResponse share_worktree_response = 8;
|
||||||
OpenWorktree open_worktree = 9;
|
OpenWorktree open_worktree = 9;
|
||||||
|
@ -40,13 +40,9 @@ message Envelope {
|
||||||
|
|
||||||
// Messages
|
// Messages
|
||||||
|
|
||||||
message Ping {
|
message Ping {}
|
||||||
int32 id = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Pong {
|
message Ack {}
|
||||||
int32 id = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Error {
|
message Error {
|
||||||
string message = 1;
|
string message = 1;
|
||||||
|
|
101
zrpc/src/conn.rs
Normal file
101
zrpc/src/conn.rs
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
||||||
|
use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
|
||||||
|
use std::{io, task::Poll};
|
||||||
|
|
||||||
|
pub struct Conn {
|
||||||
|
pub(crate) tx:
|
||||||
|
Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||||
|
pub(crate) rx: Box<
|
||||||
|
dyn 'static
|
||||||
|
+ Send
|
||||||
|
+ Unpin
|
||||||
|
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||||
|
>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Conn {
|
||||||
|
pub fn new<S>(stream: S) -> Self
|
||||||
|
where
|
||||||
|
S: 'static
|
||||||
|
+ Send
|
||||||
|
+ Unpin
|
||||||
|
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
|
||||||
|
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
|
||||||
|
{
|
||||||
|
let (tx, rx) = stream.split();
|
||||||
|
Self {
|
||||||
|
tx: Box::new(tx),
|
||||||
|
rx: Box::new(rx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
|
||||||
|
self.tx.send(message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
|
||||||
|
let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
|
||||||
|
postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
|
||||||
|
|
||||||
|
let (a_tx, a_rx) = Self::channel(kill_rx.clone());
|
||||||
|
let (b_tx, b_rx) = Self::channel(kill_rx);
|
||||||
|
(
|
||||||
|
Self { tx: a_tx, rx: b_rx },
|
||||||
|
Self { tx: b_tx, rx: a_rx },
|
||||||
|
kill_tx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
fn channel(
|
||||||
|
kill_rx: postage::watch::Receiver<Option<()>>,
|
||||||
|
) -> (
|
||||||
|
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
|
||||||
|
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
|
||||||
|
) {
|
||||||
|
use futures::{future, SinkExt as _};
|
||||||
|
use io::{Error, ErrorKind};
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
|
||||||
|
let tx = tx
|
||||||
|
.sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
|
||||||
|
.with({
|
||||||
|
let kill_rx = kill_rx.clone();
|
||||||
|
move |msg| {
|
||||||
|
if kill_rx.borrow().is_none() {
|
||||||
|
future::ready(Ok(msg))
|
||||||
|
} else {
|
||||||
|
future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let rx = KillableReceiver { kill_rx, rx };
|
||||||
|
|
||||||
|
(Box::new(tx), Box::new(rx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct KillableReceiver {
|
||||||
|
rx: mpsc::UnboundedReceiver<WebSocketMessage>,
|
||||||
|
kill_rx: postage::watch::Receiver<Option<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for KillableReceiver {
|
||||||
|
type Item = Result<WebSocketMessage, WebSocketError>;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
mut self: std::pin::Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Option<Self::Item>> {
|
||||||
|
if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
|
||||||
|
Poll::Ready(Some(Err(io::Error::new(
|
||||||
|
io::ErrorKind::Other,
|
||||||
|
"connection killed",
|
||||||
|
)
|
||||||
|
.into())))
|
||||||
|
} else {
|
||||||
|
self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
mod conn;
|
||||||
mod peer;
|
mod peer;
|
||||||
pub mod proto;
|
pub mod proto;
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
pub use conn::Conn;
|
||||||
pub mod test;
|
|
||||||
|
|
||||||
pub use peer::*;
|
pub use peer::*;
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
|
||||||
|
use super::Conn;
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use async_lock::{Mutex, RwLock};
|
use async_lock::{Mutex, RwLock};
|
||||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
use futures::FutureExt as _;
|
||||||
use futures::{FutureExt, StreamExt};
|
|
||||||
use postage::{
|
use postage::{
|
||||||
mpsc,
|
mpsc,
|
||||||
prelude::{Sink as _, Stream as _},
|
prelude::{Sink as _, Stream as _},
|
||||||
|
@ -98,21 +98,14 @@ impl Peer {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn add_connection<Conn>(
|
pub async fn add_connection(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
conn: Conn,
|
conn: Conn,
|
||||||
) -> (
|
) -> (
|
||||||
ConnectionId,
|
ConnectionId,
|
||||||
impl Future<Output = anyhow::Result<()>> + Send,
|
impl Future<Output = anyhow::Result<()>> + Send,
|
||||||
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
|
||||||
)
|
) {
|
||||||
where
|
|
||||||
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
|
|
||||||
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
|
|
||||||
+ Send
|
|
||||||
+ Unpin,
|
|
||||||
{
|
|
||||||
let (tx, rx) = conn.split();
|
|
||||||
let connection_id = ConnectionId(
|
let connection_id = ConnectionId(
|
||||||
self.next_connection_id
|
self.next_connection_id
|
||||||
.fetch_add(1, atomic::Ordering::SeqCst),
|
.fetch_add(1, atomic::Ordering::SeqCst),
|
||||||
|
@ -124,9 +117,10 @@ impl Peer {
|
||||||
next_message_id: Default::default(),
|
next_message_id: Default::default(),
|
||||||
response_channels: Default::default(),
|
response_channels: Default::default(),
|
||||||
};
|
};
|
||||||
let mut writer = MessageStream::new(tx);
|
let mut writer = MessageStream::new(conn.tx);
|
||||||
let mut reader = MessageStream::new(rx);
|
let mut reader = MessageStream::new(conn.rx);
|
||||||
|
|
||||||
|
let this = self.clone();
|
||||||
let response_channels = connection.response_channels.clone();
|
let response_channels = connection.response_channels.clone();
|
||||||
let handle_io = async move {
|
let handle_io = async move {
|
||||||
loop {
|
loop {
|
||||||
|
@ -147,6 +141,7 @@ impl Peer {
|
||||||
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
|
||||||
if incoming_tx.send(envelope).await.is_err() {
|
if incoming_tx.send(envelope).await.is_err() {
|
||||||
response_channels.lock().await.clear();
|
response_channels.lock().await.clear();
|
||||||
|
this.connections.write().await.remove(&connection_id);
|
||||||
return Ok(())
|
return Ok(())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -158,6 +153,7 @@ impl Peer {
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
response_channels.lock().await.clear();
|
response_channels.lock().await.clear();
|
||||||
|
this.connections.write().await.remove(&connection_id);
|
||||||
Err(error).context("received invalid RPC message")?;
|
Err(error).context("received invalid RPC message")?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -165,11 +161,13 @@ impl Peer {
|
||||||
Some(outgoing) => {
|
Some(outgoing) => {
|
||||||
if let Err(result) = writer.write_message(&outgoing).await {
|
if let Err(result) = writer.write_message(&outgoing).await {
|
||||||
response_channels.lock().await.clear();
|
response_channels.lock().await.clear();
|
||||||
|
this.connections.write().await.remove(&connection_id);
|
||||||
Err(result).context("failed to write RPC message")?;
|
Err(result).context("failed to write RPC message")?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
response_channels.lock().await.clear();
|
response_channels.lock().await.clear();
|
||||||
|
this.connections.write().await.remove(&connection_id);
|
||||||
return Ok(())
|
return Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -342,7 +340,9 @@ impl Peer {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{test, TypedEnvelope};
|
use crate::TypedEnvelope;
|
||||||
|
use async_tungstenite::tungstenite::Message as WebSocketMessage;
|
||||||
|
use futures::StreamExt as _;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_request_response() {
|
fn test_request_response() {
|
||||||
|
@ -352,12 +352,12 @@ mod tests {
|
||||||
let client1 = Peer::new();
|
let client1 = Peer::new();
|
||||||
let client2 = Peer::new();
|
let client2 = Peer::new();
|
||||||
|
|
||||||
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
|
let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
|
||||||
let (client1_conn_id, io_task1, _) =
|
let (client1_conn_id, io_task1, _) =
|
||||||
client1.add_connection(client1_to_server_conn).await;
|
client1.add_connection(client1_to_server_conn).await;
|
||||||
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
|
||||||
|
|
||||||
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
|
let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
|
||||||
let (client2_conn_id, io_task3, _) =
|
let (client2_conn_id, io_task3, _) =
|
||||||
client2.add_connection(client2_to_server_conn).await;
|
client2.add_connection(client2_to_server_conn).await;
|
||||||
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
|
||||||
|
@ -371,18 +371,18 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
client1
|
client1
|
||||||
.request(client1_conn_id, proto::Ping { id: 1 },)
|
.request(client1_conn_id, proto::Ping {},)
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
proto::Pong { id: 1 }
|
proto::Ack {}
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
client2
|
client2
|
||||||
.request(client2_conn_id, proto::Ping { id: 2 },)
|
.request(client2_conn_id, proto::Ping {},)
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
proto::Pong { id: 2 }
|
proto::Ack {}
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -438,13 +438,7 @@ mod tests {
|
||||||
let envelope = envelope.into_any();
|
let envelope = envelope.into_any();
|
||||||
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
|
||||||
let receipt = envelope.receipt();
|
let receipt = envelope.receipt();
|
||||||
peer.respond(
|
peer.respond(receipt, proto::Ack {}).await?
|
||||||
receipt,
|
|
||||||
proto::Pong {
|
|
||||||
id: envelope.payload.id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
} else if let Some(envelope) =
|
} else if let Some(envelope) =
|
||||||
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
|
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
|
||||||
{
|
{
|
||||||
|
@ -492,7 +486,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_disconnect() {
|
fn test_disconnect() {
|
||||||
smol::block_on(async move {
|
smol::block_on(async move {
|
||||||
let (client_conn, mut server_conn) = test::Channel::bidirectional();
|
let (client_conn, mut server_conn, _) = Conn::in_memory();
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
let (connection_id, io_handler, mut incoming) =
|
let (connection_id, io_handler, mut incoming) =
|
||||||
|
@ -516,18 +510,17 @@ mod tests {
|
||||||
|
|
||||||
io_ended_rx.recv().await;
|
io_ended_rx.recv().await;
|
||||||
messages_ended_rx.recv().await;
|
messages_ended_rx.recv().await;
|
||||||
assert!(
|
assert!(server_conn
|
||||||
futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
|
.send(WebSocketMessage::Binary(vec![]))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err());
|
||||||
);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_io_error() {
|
fn test_io_error() {
|
||||||
smol::block_on(async move {
|
smol::block_on(async move {
|
||||||
let (client_conn, server_conn) = test::Channel::bidirectional();
|
let (client_conn, server_conn, _) = Conn::in_memory();
|
||||||
drop(server_conn);
|
drop(server_conn);
|
||||||
|
|
||||||
let client = Peer::new();
|
let client = Peer::new();
|
||||||
|
@ -537,7 +530,7 @@ mod tests {
|
||||||
smol::spawn(async move { incoming.next().await }).detach();
|
smol::spawn(async move { incoming.next().await }).detach();
|
||||||
|
|
||||||
let err = client
|
let err = client
|
||||||
.request(connection_id, proto::Ping { id: 42 })
|
.request(connection_id, proto::Ping {})
|
||||||
.await
|
.await
|
||||||
.unwrap_err();
|
.unwrap_err();
|
||||||
assert_eq!(err.to_string(), "connection was closed");
|
assert_eq!(err.to_string(), "connection was closed");
|
||||||
|
|
|
@ -120,6 +120,7 @@ macro_rules! entity_messages {
|
||||||
}
|
}
|
||||||
|
|
||||||
messages!(
|
messages!(
|
||||||
|
Ack,
|
||||||
AddPeer,
|
AddPeer,
|
||||||
BufferSaved,
|
BufferSaved,
|
||||||
ChannelMessageSent,
|
ChannelMessageSent,
|
||||||
|
@ -140,7 +141,6 @@ messages!(
|
||||||
OpenWorktree,
|
OpenWorktree,
|
||||||
OpenWorktreeResponse,
|
OpenWorktreeResponse,
|
||||||
Ping,
|
Ping,
|
||||||
Pong,
|
|
||||||
RemovePeer,
|
RemovePeer,
|
||||||
SaveBuffer,
|
SaveBuffer,
|
||||||
SendChannelMessage,
|
SendChannelMessage,
|
||||||
|
@ -157,8 +157,9 @@ request_messages!(
|
||||||
(JoinChannel, JoinChannelResponse),
|
(JoinChannel, JoinChannelResponse),
|
||||||
(OpenBuffer, OpenBufferResponse),
|
(OpenBuffer, OpenBufferResponse),
|
||||||
(OpenWorktree, OpenWorktreeResponse),
|
(OpenWorktree, OpenWorktreeResponse),
|
||||||
(Ping, Pong),
|
(Ping, Ack),
|
||||||
(SaveBuffer, BufferSaved),
|
(SaveBuffer, BufferSaved),
|
||||||
|
(UpdateBuffer, Ack),
|
||||||
(ShareWorktree, ShareWorktreeResponse),
|
(ShareWorktree, ShareWorktreeResponse),
|
||||||
(SendChannelMessage, SendChannelMessageResponse),
|
(SendChannelMessage, SendChannelMessageResponse),
|
||||||
(GetChannelMessages, GetChannelMessagesResponse),
|
(GetChannelMessages, GetChannelMessagesResponse),
|
||||||
|
@ -247,30 +248,3 @@ impl From<SystemTime> for Timestamp {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::test;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_round_trip_message() {
|
|
||||||
smol::block_on(async {
|
|
||||||
let stream = test::Channel::new();
|
|
||||||
let message1 = Ping { id: 5 }.into_envelope(3, None, None);
|
|
||||||
let message2 = OpenBuffer {
|
|
||||||
worktree_id: 0,
|
|
||||||
path: "some/path".to_string(),
|
|
||||||
}
|
|
||||||
.into_envelope(5, None, None);
|
|
||||||
|
|
||||||
let mut message_stream = MessageStream::new(stream);
|
|
||||||
message_stream.write_message(&message1).await.unwrap();
|
|
||||||
message_stream.write_message(&message2).await.unwrap();
|
|
||||||
let decoded_message1 = message_stream.read_message().await.unwrap();
|
|
||||||
let decoded_message2 = message_stream.read_message().await.unwrap();
|
|
||||||
assert_eq!(decoded_message1, message1);
|
|
||||||
assert_eq!(decoded_message2, message2);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,64 +0,0 @@
|
||||||
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
|
|
||||||
use std::{
|
|
||||||
io,
|
|
||||||
pin::Pin,
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct Channel {
|
|
||||||
tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
|
|
||||||
rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Channel {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
|
||||||
Self { tx, rx }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn bidirectional() -> (Self, Self) {
|
|
||||||
let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
|
|
||||||
let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
|
|
||||||
let a = Self { tx: a_tx, rx: b_rx };
|
|
||||||
let b = Self { tx: b_tx, rx: a_rx };
|
|
||||||
(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl futures::Sink<WebSocketMessage> for Channel {
|
|
||||||
type Error = WebSocketError;
|
|
||||||
|
|
||||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Pin::new(&mut self.tx)
|
|
||||||
.poll_ready(cx)
|
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
|
|
||||||
Pin::new(&mut self.tx)
|
|
||||||
.start_send(item)
|
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Pin::new(&mut self.tx)
|
|
||||||
.poll_flush(cx)
|
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Pin::new(&mut self.tx)
|
|
||||||
.poll_close(cx)
|
|
||||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl futures::Stream for Channel {
|
|
||||||
type Item = Result<WebSocketMessage, WebSocketError>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
Pin::new(&mut self.rx)
|
|
||||||
.poll_next(cx)
|
|
||||||
.map(|i| i.map(|i| Ok(i)))
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue