use crate::{App, PlatformDispatcher}; use async_task::Runnable; use futures::channel::mpsc; use smol::prelude::*; use std::mem::ManuallyDrop; use std::panic::Location; use std::thread::{self, ThreadId}; use std::{ fmt::Debug, marker::PhantomData, mem, num::NonZeroUsize, pin::Pin, rc::Rc, sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, Arc, }, task::{Context, Poll}, time::{Duration, Instant}, }; use util::TryFutureExt; use waker_fn::waker_fn; #[cfg(any(test, feature = "test-support"))] use rand::rngs::StdRng; /// A pointer to the executor that is currently running, /// for spawning background tasks. #[derive(Clone)] pub struct BackgroundExecutor { #[doc(hidden)] pub dispatcher: Arc, } /// A pointer to the executor that is currently running, /// for spawning tasks on the main thread. #[derive(Clone)] pub struct ForegroundExecutor { #[doc(hidden)] pub dispatcher: Arc, not_send: PhantomData>, } /// Task is a primitive that allows work to happen in the background. /// /// It implements [`Future`] so you can `.await` on it. /// /// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows /// the task to continue running, but with no way to return a value. #[must_use] #[derive(Debug)] pub struct Task(TaskState); #[derive(Debug)] enum TaskState { /// A task that is ready to return a value Ready(Option), /// A task that is currently running. Spawned(async_task::Task), } impl Task { /// Creates a new task that will resolve with the value pub fn ready(val: T) -> Self { Task(TaskState::Ready(Some(val))) } /// Detaching a task runs it to completion in the background pub fn detach(self) { match self { Task(TaskState::Ready(_)) => {} Task(TaskState::Spawned(task)) => task.detach(), } } } impl Task> where T: 'static, E: 'static + Debug, { /// Run the task to completion in the background and log any /// errors that occur. #[track_caller] pub fn detach_and_log_err(self, cx: &App) { let location = core::panic::Location::caller(); cx.foreground_executor() .spawn(self.log_tracked_err(*location)) .detach(); } } impl Future for Task { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { match unsafe { self.get_unchecked_mut() } { Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()), Task(TaskState::Spawned(task)) => task.poll(cx), } } } /// A task label is an opaque identifier that you can use to /// refer to a task in tests. #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct TaskLabel(NonZeroUsize); impl Default for TaskLabel { fn default() -> Self { Self::new() } } impl TaskLabel { /// Construct a new task label. pub fn new() -> Self { static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1); Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap()) } } type AnyLocalFuture = Pin>>; type AnyFuture = Pin>>; /// BackgroundExecutor lets you run things on background threads. /// In production this is a thread pool with no ordering guarantees. /// In tests this is simulated by running tasks one by one in a deterministic /// (but arbitrary) order controlled by the `SEED` environment variable. impl BackgroundExecutor { #[doc(hidden)] pub fn new(dispatcher: Arc) -> Self { Self { dispatcher } } /// Enqueues the given future to be run to completion on a background thread. pub fn spawn(&self, future: impl Future + Send + 'static) -> Task where R: Send + 'static, { self.spawn_internal::(Box::pin(future), None) } /// Enqueues the given future to be run to completion on a background thread. /// The given label can be used to control the priority of the task in tests. pub fn spawn_labeled( &self, label: TaskLabel, future: impl Future + Send + 'static, ) -> Task where R: Send + 'static, { self.spawn_internal::(Box::pin(future), Some(label)) } fn spawn_internal( &self, future: AnyFuture, label: Option, ) -> Task { let dispatcher = self.dispatcher.clone(); let (runnable, task) = async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label)); runnable.schedule(); Task(TaskState::Spawned(task)) } /// Used by the test harness to run an async test in a synchronous fashion. #[cfg(any(test, feature = "test-support"))] #[track_caller] pub fn block_test(&self, future: impl Future) -> R { if let Ok(value) = self.block_internal(false, future, None) { value } else { unreachable!() } } /// Block the current thread until the given future resolves. /// Consider using `block_with_timeout` instead. pub fn block(&self, future: impl Future) -> R { if let Ok(value) = self.block_internal(true, future, None) { value } else { unreachable!() } } #[cfg(not(any(test, feature = "test-support")))] pub(crate) fn block_internal( &self, _background_only: bool, future: impl Future, timeout: Option, ) -> Result> { use std::time::Instant; let mut future = Box::pin(future); if timeout == Some(Duration::ZERO) { return Err(future); } let deadline = timeout.map(|timeout| Instant::now() + timeout); let unparker = self.dispatcher.unparker(); let waker = waker_fn(move || { unparker.unpark(); }); let mut cx = std::task::Context::from_waker(&waker); loop { match future.as_mut().poll(&mut cx) { Poll::Ready(result) => return Ok(result), Poll::Pending => { let timeout = deadline.map(|deadline| deadline.saturating_duration_since(Instant::now())); if !self.dispatcher.park(timeout) && deadline.is_some_and(|deadline| deadline < Instant::now()) { return Err(future); } } } } } #[cfg(any(test, feature = "test-support"))] #[track_caller] pub(crate) fn block_internal( &self, background_only: bool, future: impl Future, timeout: Option, ) -> Result> { use std::sync::atomic::AtomicBool; let mut future = Box::pin(future); if timeout == Some(Duration::ZERO) { return Err(future); } let Some(dispatcher) = self.dispatcher.as_test() else { return Err(future); }; let mut max_ticks = if timeout.is_some() { dispatcher.gen_block_on_ticks() } else { usize::MAX }; let unparker = self.dispatcher.unparker(); let awoken = Arc::new(AtomicBool::new(false)); let waker = waker_fn({ let awoken = awoken.clone(); move || { awoken.store(true, SeqCst); unparker.unpark(); } }); let mut cx = std::task::Context::from_waker(&waker); loop { match future.as_mut().poll(&mut cx) { Poll::Ready(result) => return Ok(result), Poll::Pending => { if max_ticks == 0 { return Err(future); } max_ticks -= 1; if !dispatcher.tick(background_only) { if awoken.swap(false, SeqCst) { continue; } if !dispatcher.parking_allowed() { let mut backtrace_message = String::new(); let mut waiting_message = String::new(); if let Some(backtrace) = dispatcher.waiting_backtrace() { backtrace_message = format!("\nbacktrace of waiting future:\n{:?}", backtrace); } if let Some(waiting_hint) = dispatcher.waiting_hint() { waiting_message = format!("\n waiting on: {}\n", waiting_hint); } panic!( "parked with nothing left to run{waiting_message}{backtrace_message}", ) } self.dispatcher.park(None); } } } } } /// Block the current thread until the given future resolves /// or `duration` has elapsed. pub fn block_with_timeout( &self, duration: Duration, future: impl Future, ) -> Result> { self.block_internal(true, future, Some(duration)) } /// Scoped lets you start a number of tasks and waits /// for all of them to complete before returning. pub async fn scoped<'scope, F>(&self, scheduler: F) where F: FnOnce(&mut Scope<'scope>), { let mut scope = Scope::new(self.clone()); (scheduler)(&mut scope); let spawned = mem::take(&mut scope.futures) .into_iter() .map(|f| self.spawn(f)) .collect::>(); for task in spawned { task.await; } } /// Get the current time. /// /// Calling this instead of `std::time::Instant::now` allows the use /// of fake timers in tests. pub fn now(&self) -> Instant { self.dispatcher.now() } /// Returns a task that will complete after the given duration. /// Depending on other concurrent tasks the elapsed duration may be longer /// than requested. pub fn timer(&self, duration: Duration) -> Task<()> { if duration.is_zero() { return Task::ready(()); } let (runnable, task) = async_task::spawn(async move {}, { let dispatcher = self.dispatcher.clone(); move |runnable| dispatcher.dispatch_after(duration, runnable) }); runnable.schedule(); Task(TaskState::Spawned(task)) } /// in tests, start_waiting lets you indicate which task is waiting (for debugging only) #[cfg(any(test, feature = "test-support"))] pub fn start_waiting(&self) { self.dispatcher.as_test().unwrap().start_waiting(); } /// in tests, removes the debugging data added by start_waiting #[cfg(any(test, feature = "test-support"))] pub fn finish_waiting(&self) { self.dispatcher.as_test().unwrap().finish_waiting(); } /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable) #[cfg(any(test, feature = "test-support"))] pub fn simulate_random_delay(&self) -> impl Future { self.dispatcher.as_test().unwrap().simulate_random_delay() } /// in tests, indicate that a given task from `spawn_labeled` should run after everything else #[cfg(any(test, feature = "test-support"))] pub fn deprioritize(&self, task_label: TaskLabel) { self.dispatcher.as_test().unwrap().deprioritize(task_label) } /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready. #[cfg(any(test, feature = "test-support"))] pub fn advance_clock(&self, duration: Duration) { self.dispatcher.as_test().unwrap().advance_clock(duration) } /// in tests, run one task. #[cfg(any(test, feature = "test-support"))] pub fn tick(&self) -> bool { self.dispatcher.as_test().unwrap().tick(false) } /// in tests, run all tasks that are ready to run. If after doing so /// the test still has outstanding tasks, this will panic. (See also `allow_parking`) #[cfg(any(test, feature = "test-support"))] pub fn run_until_parked(&self) { self.dispatcher.as_test().unwrap().run_until_parked() } /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks. /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that /// do take real async time to run. #[cfg(any(test, feature = "test-support"))] pub fn allow_parking(&self) { self.dispatcher.as_test().unwrap().allow_parking(); } /// undoes the effect of [`allow_parking`]. #[cfg(any(test, feature = "test-support"))] pub fn forbid_parking(&self) { self.dispatcher.as_test().unwrap().forbid_parking(); } /// adds detail to the "parked with nothing let to run" message. #[cfg(any(test, feature = "test-support"))] pub fn set_waiting_hint(&self, msg: Option) { self.dispatcher.as_test().unwrap().set_waiting_hint(msg); } /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable #[cfg(any(test, feature = "test-support"))] pub fn rng(&self) -> StdRng { self.dispatcher.as_test().unwrap().rng() } /// How many CPUs are available to the dispatcher. pub fn num_cpus(&self) -> usize { #[cfg(any(test, feature = "test-support"))] return 4; #[cfg(not(any(test, feature = "test-support")))] return num_cpus::get(); } /// Whether we're on the main thread. pub fn is_main_thread(&self) -> bool { self.dispatcher.is_main_thread() } #[cfg(any(test, feature = "test-support"))] /// in tests, control the number of ticks that `block_with_timeout` will run before timing out. pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive) { self.dispatcher.as_test().unwrap().set_block_on_ticks(range); } } /// ForegroundExecutor runs things on the main thread. impl ForegroundExecutor { /// Creates a new ForegroundExecutor from the given PlatformDispatcher. pub fn new(dispatcher: Arc) -> Self { Self { dispatcher, not_send: PhantomData, } } /// Enqueues the given Task to run on the main thread at some point in the future. #[track_caller] pub fn spawn(&self, future: impl Future + 'static) -> Task where R: 'static, { let dispatcher = self.dispatcher.clone(); #[track_caller] fn inner( dispatcher: Arc, future: AnyLocalFuture, ) -> Task { let (runnable, task) = spawn_local_with_source_location(future, move |runnable| { dispatcher.dispatch_on_main_thread(runnable) }); runnable.schedule(); Task(TaskState::Spawned(task)) } inner::(dispatcher, Box::pin(future)) } } /// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics. /// /// Copy-modified from: /// https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405 #[track_caller] fn spawn_local_with_source_location( future: Fut, schedule: S, ) -> (Runnable<()>, async_task::Task) where Fut: Future + 'static, Fut::Output: 'static, S: async_task::Schedule<()> + Send + Sync + 'static, { #[inline] fn thread_id() -> ThreadId { std::thread_local! { static ID: ThreadId = thread::current().id(); } ID.try_with(|id| *id) .unwrap_or_else(|_| thread::current().id()) } struct Checked { id: ThreadId, inner: ManuallyDrop, location: &'static Location<'static>, } impl Drop for Checked { fn drop(&mut self) { assert!( self.id == thread_id(), "local task dropped by a thread that didn't spawn it. Task spawned at {}", self.location ); unsafe { ManuallyDrop::drop(&mut self.inner); } } } impl Future for Checked { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { assert!( self.id == thread_id(), "local task polled by a thread that didn't spawn it. Task spawned at {}", self.location ); unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) } } } // Wrap the future into one that checks which thread it's on. let future = Checked { id: thread_id(), inner: ManuallyDrop::new(future), location: Location::caller(), }; unsafe { async_task::spawn_unchecked(future, schedule) } } /// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`]. pub struct Scope<'a> { executor: BackgroundExecutor, futures: Vec + Send + 'static>>>, tx: Option>, rx: mpsc::Receiver<()>, lifetime: PhantomData<&'a ()>, } impl<'a> Scope<'a> { fn new(executor: BackgroundExecutor) -> Self { let (tx, rx) = mpsc::channel(1); Self { executor, tx: Some(tx), rx, futures: Default::default(), lifetime: PhantomData, } } /// How many CPUs are available to the dispatcher. pub fn num_cpus(&self) -> usize { self.executor.num_cpus() } /// Spawn a future into this scope. pub fn spawn(&mut self, f: F) where F: Future + Send + 'a, { let tx = self.tx.clone().unwrap(); // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because // dropping this `Scope` blocks until all of the futures have resolved. let f = unsafe { mem::transmute::< Pin + Send + 'a>>, Pin + Send + 'static>>, >(Box::pin(async move { f.await; drop(tx); })) }; self.futures.push(f); } } impl<'a> Drop for Scope<'a> { fn drop(&mut self) { self.tx.take().unwrap(); // Wait until the channel is closed, which means that all of the spawned // futures have resolved. self.executor.block(self.rx.next()); } }