use crate::signal::os::{OsExtraData, OsStorage};
use crate::sync::watch;
use crate::util::once_cell::OnceCell;
use std::ops;
use std::sync::atomic::{AtomicBool, Ordering};
/// Identifier used to look up an event's [`EventInfo`] in a [`Storage`].
///
/// The `Vec<EventInfo>` storage in this file treats it as an index.
pub(crate) type EventId = usize;
/// Per-event bookkeeping: a "pending" flag plus the channel used to notify
/// listeners of that event.
#[derive(Debug)]
pub(crate) struct EventInfo {
    /// Set to `true` by `Registry::record_event` and swapped back to `false`
    /// by `Registry::broadcast`.
    pending: AtomicBool,
    /// Sending half of the watch channel; `Registry::register_listener`
    /// hands out receivers obtained from `tx.subscribe()`.
    tx: watch::Sender<()>,
}
impl Default for EventInfo {
fn default() -> Self {
let (tx, _rx) = watch::channel(());
Self {
pending: AtomicBool::new(false),
tx,
}
}
}
/// Abstraction over a collection of [`EventInfo`] entries addressed by
/// [`EventId`].
pub(crate) trait Storage {
    /// Looks up the entry associated with `id`, if there is one.
    fn event_info(&self, id: EventId) -> Option<&EventInfo>;
    /// Visits the stored entries, passing each one to `f`.
    fn for_each<'a, F>(&'a self, f: F)
    where
        F: FnMut(&'a EventInfo);
}
impl Storage for Vec<EventInfo> {
    /// Treats `id` as an index into the vector; out-of-range ids yield `None`.
    fn event_info(&self, id: EventId) -> Option<&EventInfo> {
        self.as_slice().get(id)
    }

    /// Calls `f` once per element, in vector order.
    fn for_each<'a, F>(&'a self, mut f: F)
    where
        F: FnMut(&'a EventInfo),
    {
        for info in self {
            f(info);
        }
    }
}
/// Argument-free constructor used by `globals_init` to build the
/// `OsExtraData` and `OsStorage` values.
pub(crate) trait Init {
    /// Produces the initial value of the implementing type.
    fn init() -> Self;
}
/// Wrapper around an event storage `S`.
///
/// The event-handling methods are provided in a separate `impl` for storages
/// that implement `Storage`.
#[derive(Debug)]
pub(crate) struct Registry<S> {
    /// The wrapped event storage.
    storage: S,
}

impl<S> Registry<S> {
    /// Wraps `storage` in a new registry.
    fn new(storage: S) -> Self {
        Registry { storage }
    }
}
impl<S: Storage> Registry<S> {
    /// Subscribes a new receiver to the channel of the event `event_id`.
    ///
    /// # Panics
    ///
    /// Panics with `invalid event_id: {event_id}` when the storage has no
    /// entry for `event_id`.
    fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
        match self.storage.event_info(event_id) {
            Some(info) => info.tx.subscribe(),
            None => panic!("invalid event_id: {event_id}"),
        }
    }

    /// Flags the event `event_id` as pending; unknown ids are silently
    /// ignored.
    fn record_event(&self, event_id: EventId) {
        let slot = self.storage.event_info(event_id);
        if let Some(info) = slot {
            info.pending.store(true, Ordering::SeqCst);
        }
    }

    /// Sends a notification for every event currently flagged as pending,
    /// clearing the flag in the process.
    ///
    /// Returns `true` if at least one send succeeded.
    fn broadcast(&self) -> bool {
        let mut notified_any = false;
        self.storage.for_each(|info| {
            // The swap both reads and resets the flag, so an entry that was
            // not pending is skipped without touching its channel.
            let was_pending = info.pending.swap(false, Ordering::SeqCst);
            // A failed send is not an error here; it simply does not count
            // as a notification.
            if was_pending && info.tx.send(()).is_ok() {
                notified_any = true;
            }
        });
        notified_any
    }
}
/// State shared through the `globals()` accessor: OS-specific extra data
/// plus the event registry.
pub(crate) struct Globals {
    /// OS-specific data; exposed to callers through the `Deref` impl.
    extra: OsExtraData,
    /// Registry over the OS-specific event storage.
    registry: Registry<OsStorage>,
}
impl ops::Deref for Globals {
type Target = OsExtraData;
fn deref(&self) -> &Self::Target {
&self.extra
}
}
impl Globals {
pub(crate) fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
self.registry.register_listener(event_id)
}
pub(crate) fn record_event(&self, event_id: EventId) {
self.registry.record_event(event_id);
}
pub(crate) fn broadcast(&self) -> bool {
self.registry.broadcast()
}
#[cfg(unix)]
pub(crate) fn storage(&self) -> &OsStorage {
&self.registry.storage
}
}
/// Builds a fresh `Globals` from the `Init` implementations of the
/// OS-specific types.
fn globals_init() -> Globals
where
    OsExtraData: 'static + Send + Sync + Init,
    OsStorage: 'static + Send + Sync + Init,
{
    // The extra data is initialised before the storage, in that order.
    let extra = OsExtraData::init();
    let registry = Registry::new(OsStorage::init());
    Globals { extra, registry }
}
/// Returns the shared `Globals` instance.
///
/// The value lives in a function-local `OnceCell` static, which is handed
/// `globals_init` as its initializer.
pub(crate) fn globals() -> &'static Globals
where
    OsExtraData: 'static + Send + Sync + Init,
    OsStorage: 'static + Send + Sync + Init,
{
    static SHARED: OnceCell<Globals> = OnceCell::new();

    SHARED.get(globals_init)
}
// Unit tests for `Registry`, using a plain `Vec<EventInfo>` as storage.
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::{self, Runtime};
    use crate::sync::{oneshot, watch};
    use futures::future;
    // End-to-end check: events recorded several times before one broadcast
    // are expected to produce one notification per broadcast, and an event
    // that was never recorded produces none.
    #[test]
    fn smoke() {
        let rt = rt();
        rt.block_on(async move {
            let registry = Registry::new(vec![
                EventInfo::default(),
                EventInfo::default(),
                EventInfo::default(),
            ]);
            let first = registry.register_listener(0);
            let second = registry.register_listener(1);
            let third = registry.register_listener(2);
            let (fire, wait) = oneshot::channel();
            // The spawned task takes ownership of the registry and only
            // starts recording once `fire` has been sent.
            crate::spawn(async {
                wait.await.expect("wait failed");
                // Events 0 and 1 are each recorded twice before a single
                // broadcast.
                registry.record_event(0);
                registry.record_event(0);
                registry.record_event(1);
                registry.record_event(1);
                registry.broadcast();
                // NOTE(review): the yields presumably give the collectors a
                // chance to observe the first broadcast before the second
                // one is sent -- confirm.
                for _ in 0..100 {
                    crate::task::yield_now().await;
                }
                // Second round: only event 0 is recorded and broadcast.
                registry.record_event(0);
                registry.broadcast();
                // NOTE(review): dropping the registry presumably closes the
                // channels so that the `collect` loops terminate -- confirm
                // against the watch channel docs.
                drop(registry);
            });
            let _ = fire.send(());
            let all = future::join3(collect(first), collect(second), collect(third));
            let (first_results, second_results, third_results) = all.await;
            // Event 0 was broadcast twice, event 1 once, event 2 never.
            assert_eq!(2, first_results.len());
            assert_eq!(1, second_results.len());
            assert_eq!(0, third_results.len());
        });
    }
    // Registering for an id with no entry must panic with the exact message
    // checked by `should_panic`.
    #[test]
    #[should_panic = "invalid event_id: 1"]
    fn register_panics_on_invalid_input() {
        let registry = Registry::new(vec![EventInfo::default()]);
        registry.register_listener(1);
    }
    // Recording an id with no entry must be a silent no-op (no panic).
    #[test]
    fn record_invalid_event_does_nothing() {
        let registry = Registry::new(vec![EventInfo::default()]);
        registry.record_event(1302);
    }
    // Checks the boolean returned by `broadcast` in three situations.
    #[test]
    fn broadcast_returns_if_at_least_one_event_fired() {
        let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]);
        // No listener has been registered yet, so the broadcast is expected
        // to report `false`.
        registry.record_event(0);
        assert!(!registry.broadcast());
        let first = registry.register_listener(0);
        let second = registry.register_listener(1);
        // With a listener on event 0, the broadcast is expected to report
        // `true`.
        registry.record_event(0);
        assert!(registry.broadcast());
        // Once the only listener of event 0 is dropped, the broadcast is
        // expected to report `false` again, although event 1 still has one.
        drop(first);
        registry.record_event(0);
        assert!(!registry.broadcast());
        drop(second);
    }
    // Builds a current-thread runtime with the time driver enabled.
    fn rt() -> Runtime {
        runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap()
    }
    // Gathers one `()` per successful `changed()` call and returns them once
    // `changed()` reports an error.
    async fn collect(mut rx: watch::Receiver<()>) -> Vec<()> {
        let mut ret = vec![];
        while let Ok(v) = rx.changed().await {
            ret.push(v);
        }
        ret
    }
}