This commit is contained in:
Antonio Scandurra 2023-10-21 18:30:44 +02:00
parent aa3fb28f81
commit 7bb99c9b9c
3 changed files with 109 additions and 12 deletions

View file

@ -1,7 +1,10 @@
use crate::{AppContext, PlatformDispatcher};
use futures::channel::mpsc;
use smol::prelude::*;
use std::{
fmt::Debug,
marker::PhantomData,
mem,
pin::Pin,
sync::Arc,
task::{Context, Poll},
@ -133,7 +136,73 @@ impl Executor {
futures::executor::block_on(future)
}
pub async fn scoped<'scope, F>(&self, scheduler: F)
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone());
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn(f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
}
}
pub fn is_main_thread(&self) -> bool {
self.dispatcher.is_main_thread()
}
}
pub struct Scope<'a> {
executor: Executor,
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
lifetime: PhantomData<&'a ()>,
}
impl<'a> Scope<'a> {
fn new(executor: Executor) -> Self {
let (tx, rx) = mpsc::channel(1);
Self {
executor,
tx: Some(tx),
rx,
futures: Default::default(),
lifetime: PhantomData,
}
}
pub fn spawn<F>(&mut self, f: F)
where
F: Future<Output = ()> + Send + 'a,
{
let tx = self.tx.clone().unwrap();
// Safety: The 'a lifetime is guaranteed to outlive any of these futures because
// dropping this `Scope` blocks until all of the futures have resolved.
let f = unsafe {
mem::transmute::<
Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
>(Box::pin(async move {
f.await;
drop(tx);
}))
};
self.futures.push(f);
}
}
impl<'a> Drop for Scope<'a> {
fn drop(&mut self) {
self.tx.take().unwrap();
// Wait until the channel is closed, which means that all of the spawned
// futures have resolved.
self.executor.block(self.rx.next());
}
}