zbus/
socket_reader.rs

1use std::{
2    collections::HashMap,
3    sync::{self, Arc},
4};
5
6use futures_util::future::poll_fn;
7use tracing::{debug, instrument, trace};
8
9use crate::{
10    async_lock::Mutex, raw::Connection as RawConnection, Executor, MsgBroadcaster, OwnedMatchRule,
11    Socket, Task,
12};
13
14#[derive(Debug)]
15pub(crate) struct SocketReader {
16    raw_conn: Arc<sync::Mutex<RawConnection<Box<dyn Socket>>>>,
17    senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
18}
19
20impl SocketReader {
21    pub fn new(
22        raw_conn: Arc<sync::Mutex<RawConnection<Box<dyn Socket>>>>,
23        senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
24    ) -> Self {
25        Self { raw_conn, senders }
26    }
27
28    pub fn spawn(self, executor: &Executor<'_>) -> Task<()> {
29        executor.spawn(self.receive_msg(), "socket reader")
30    }
31
32    // Keep receiving messages and put them on the queue.
33    #[instrument(name = "socket reader", skip(self))]
34    async fn receive_msg(self) {
35        loop {
36            trace!("Waiting for message on the socket..");
37            let msg = {
38                poll_fn(|cx| {
39                    let mut raw_conn = self.raw_conn.lock().expect("poisoned lock");
40                    raw_conn.try_receive_message(cx)
41                })
42                .await
43                .map(Arc::new)
44            };
45            match &msg {
46                Ok(msg) => trace!("Message received on the socket: {:?}", msg),
47                Err(e) => trace!("Error reading from the socket: {:?}", e),
48            };
49
50            let mut senders = self.senders.lock().await;
51            for (rule, sender) in &*senders {
52                if let Ok(msg) = &msg {
53                    if let Some(rule) = rule.as_ref() {
54                        match rule.matches(msg) {
55                            Ok(true) => (),
56                            Ok(false) => continue,
57                            Err(e) => {
58                                debug!("Error matching message against rule: {:?}", e);
59
60                                continue;
61                            }
62                        }
63                    }
64                }
65
66                if let Err(e) = sender.broadcast(msg.clone()).await {
67                    // An error would be due to either of these:
68                    //
69                    // 1. the channel is closed.
70                    // 2. No active receivers.
71                    //
72                    // In either case, just log it.
73                    trace!(
74                        "Error broadcasting message to stream for `{:?}`: {:?}",
75                        rule,
76                        e
77                    );
78                }
79            }
80            trace!("Broadcasted to all streams: {:?}", msg);
81
82            if msg.is_err() {
83                senders.clear();
84                trace!("Socket reading task stopped");
85
86                return;
87            }
88        }
89    }
90}