zbus/connection/socket/
mod.rs

1#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9#[cfg(unix)]
10pub(crate) mod command;
11#[cfg(unix)]
12pub(crate) use command::Command;
13mod tcp;
14mod unix;
15mod vsock;
16
17#[cfg(not(feature = "tokio"))]
18use async_io::Async;
19#[cfg(not(feature = "tokio"))]
20use std::sync::Arc;
21use std::{io, mem};
22use tracing::trace;
23
24use crate::{
25    conn::AuthMechanism,
26    fdo::ConnectionCredentials,
27    message::{
28        header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
29        PrimaryHeader,
30    },
31    padding_for_8_bytes, Message,
32};
33#[cfg(unix)]
34use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
35use zvariant::{
36    serialized::{self, Context},
37    Endian,
38};
39
/// Result of one low-level `recvmsg` call on targets without Unix file descriptor passing:
/// just the number of bytes read.
#[cfg(not(unix))]
type RecvmsgResult = io::Result<usize>;

/// Result of one low-level `recvmsg` call on Unix: the number of bytes read, paired with any
/// file descriptors that arrived alongside those bytes.
#[cfg(unix)]
type RecvmsgResult = io::Result<(usize, Vec<OwnedFd>)>;
45
46/// Trait representing some transport layer over which the DBus protocol can be used.
47///
48/// In order to allow simultaneous reading and writing, this trait requires you to split the socket
49/// into a read half and a write half. The reader and writer halves can be any types that implement
50/// [`ReadHalf`] and [`WriteHalf`] respectively.
51///
52/// The crate provides implementations for `async_io` and `tokio`'s `UnixStream` wrappers if you
53/// enable the corresponding crate features (`async_io` is enabled by default).
54///
55/// You can implement it manually to integrate with other runtimes or other dbus transports.  Feel
56/// free to submit pull requests to add support for more runtimes to zbus itself so rust's orphan
57/// rules don't force the use of a wrapper struct (and to avoid duplicating the work across many
58/// projects).
59pub trait Socket {
60    type ReadHalf: ReadHalf;
61    type WriteHalf: WriteHalf;
62
63    /// Split the socket into a read half and a write half.
64    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
65    where
66        Self: Sized;
67}
68
69/// The read half of a socket.
70///
71/// See [`Socket`] for more details.
72#[async_trait::async_trait]
73pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
74    /// Receive a message on the socket.
75    ///
76    /// This is the higher-level method to receive a full D-Bus message.
77    ///
78    /// The default implementation uses `recvmsg` to receive the message. Implementers should
79    /// override either this or `recvmsg`. Note that if you override this method, zbus will not be
80    /// able perform an authentication handshake and hence will skip the handshake. Therefore your
81    /// implementation will only be useful for pre-authenticated connections or connections that do
82    /// not require authentication.
83    ///
84    /// # Parameters
85    ///
86    /// - `seq`: The sequence number of the message. The returned message should have this sequence.
87    /// - `already_received_bytes`: Sometimes, zbus already received some bytes from the socket
88    ///   belonging to the first message(s) (as part of the connection handshake process). This is
89    ///   the buffer containing those bytes (if any). If you're implementing this method, most
90    ///   likely you can safely ignore this parameter.
91    /// - `already_received_fds`: Same goes for file descriptors belonging to first messages.
92    async fn receive_message(
93        &mut self,
94        seq: u64,
95        already_received_bytes: &mut Vec<u8>,
96        #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
97    ) -> crate::Result<Message> {
98        #[cfg(unix)]
99        let mut fds = vec![];
100        let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
101            let mut bytes = vec![];
102            if !already_received_bytes.is_empty() {
103                mem::swap(already_received_bytes, &mut bytes);
104            }
105            let mut pos = bytes.len();
106            bytes.resize(MIN_MESSAGE_SIZE, 0);
107            // We don't have enough data to make a proper message header yet.
108            // Some partial read may be in raw_in_buffer, so we try to complete it
109            // until we have MIN_MESSAGE_SIZE bytes
110            //
111            // Given that MIN_MESSAGE_SIZE is 16, this codepath is actually extremely unlikely
112            // to be taken more than once
113            while pos < MIN_MESSAGE_SIZE {
114                let res = self.recvmsg(&mut bytes[pos..]).await?;
115                let len = {
116                    #[cfg(unix)]
117                    {
118                        fds.extend(res.1);
119                        res.0
120                    }
121                    #[cfg(not(unix))]
122                    {
123                        res
124                    }
125                };
126                pos += len;
127                if len == 0 {
128                    return Err(std::io::Error::new(
129                        std::io::ErrorKind::UnexpectedEof,
130                        "failed to receive message",
131                    )
132                    .into());
133                }
134            }
135
136            bytes
137        } else {
138            already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
139        };
140
141        let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
142        let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
143        let body_padding = padding_for_8_bytes(header_len);
144        let body_len = primary_header.body_len() as usize;
145        let total_len = header_len + body_padding + body_len;
146        if total_len > MAX_MESSAGE_SIZE {
147            return Err(crate::Error::ExcessData);
148        }
149
150        // By this point we have a full primary header, so we know the exact length of the complete
151        // message.
152        if !already_received_bytes.is_empty() {
153            // still have some bytes buffered.
154            let pending = total_len - bytes.len();
155            let to_take = std::cmp::min(pending, already_received_bytes.len());
156            bytes.extend(already_received_bytes.drain(..to_take));
157        }
158        let mut pos = bytes.len();
159        bytes.resize(total_len, 0);
160
161        // Read the rest, if any
162        while pos < total_len {
163            let res = self.recvmsg(&mut bytes[pos..]).await?;
164            let read = {
165                #[cfg(unix)]
166                {
167                    fds.extend(res.1);
168                    res.0
169                }
170                #[cfg(not(unix))]
171                {
172                    res
173                }
174            };
175            pos += read;
176            if read == 0 {
177                return Err(crate::Error::InputOutput(
178                    std::io::Error::new(
179                        std::io::ErrorKind::UnexpectedEof,
180                        "failed to receive message",
181                    )
182                    .into(),
183                ));
184            }
185        }
186
187        // If we reach here, the message is complete; return it
188        let endian = Endian::from(primary_header.endian_sig());
189
190        #[cfg(unix)]
191        if !already_received_fds.is_empty() {
192            use crate::message::header::PRIMARY_HEADER_SIZE;
193
194            let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
195            let encoded_fields =
196                serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
197            let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
198            let num_required_fds = match fields.unix_fds {
199                Some(num_fds) => num_fds as usize,
200                _ => 0,
201            };
202            let num_pending = num_required_fds
203                .checked_sub(fds.len())
204                .ok_or_else(|| crate::Error::ExcessData)?;
205            // If we had previously received FDs, `num_pending` has to be > 0
206            if num_pending == 0 {
207                return Err(crate::Error::MissingParameter("Missing file descriptors"));
208            }
209            // All previously received FDs must go first in the list.
210            let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
211            mem::swap(&mut already_received, &mut fds);
212            fds.extend(already_received);
213        }
214
215        let ctxt = Context::new_dbus(endian, 0);
216        #[cfg(unix)]
217        let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
218        #[cfg(not(unix))]
219        let bytes = serialized::Data::new(bytes, ctxt);
220        Message::from_raw_parts(bytes, seq)
221    }
222
223    /// Attempt to receive bytes from the socket.
224    ///
225    /// On success, returns the number of bytes read as well as a `Vec` containing
226    /// any associated file descriptors.
227    ///
228    /// The default implementation simply panics. Implementers must override either `read_message`
229    /// or this method.
230    async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
231        unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
232    }
233
234    /// Return whether passing file descriptors is supported.
235    ///
236    /// Default implementation returns `false`.
237    fn can_pass_unix_fd(&self) -> bool {
238        false
239    }
240
241    /// The peer credentials.
242    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
243        Ok(ConnectionCredentials::default())
244    }
245
246    /// The authentication mechanism to use for this socket on the target OS.
247    ///
248    /// Default is `AuthMechanism::External`.
249    fn auth_mechanism(&self) -> AuthMechanism {
250        AuthMechanism::External
251    }
252}
253
254/// The write half of a socket.
255///
256/// See [`Socket`] for more details.
257#[async_trait::async_trait]
258pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
259    /// Send a message on the socket.
260    ///
261    /// This is the higher-level method to send a full D-Bus message.
262    ///
263    /// The default implementation uses `sendmsg` to send the message. Implementers should override
264    /// either this or `sendmsg`.
265    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
266        let data = msg.data();
267        let serial = msg.primary_header().serial_num();
268
269        trace!("Sending message: {:?}", msg);
270        let mut pos = 0;
271        while pos < data.len() {
272            #[cfg(unix)]
273            let fds = if pos == 0 {
274                data.fds().iter().map(|f| f.as_fd()).collect()
275            } else {
276                vec![]
277            };
278            pos += self
279                .sendmsg(
280                    &data[pos..],
281                    #[cfg(unix)]
282                    &fds,
283                )
284                .await?;
285        }
286        trace!("Sent message with serial: {}", serial);
287
288        Ok(())
289    }
290
291    /// Attempt to send a message on the socket
292    ///
293    /// On success, return the number of bytes written. There may be a partial write, in
294    /// which case the caller is responsible for sending the remaining data by calling this
295    /// method again until everything is written or it returns an error of kind `WouldBlock`.
296    ///
297    /// If at least one byte has been written, then all the provided file descriptors will
298    /// have been sent as well, and should not be provided again in subsequent calls.
299    ///
300    /// If the underlying transport does not support transmitting file descriptors, this
301    /// will return `Err(ErrorKind::InvalidInput)`.
302    ///
303    /// The default implementation simply panics. Implementers must override either `send_message`
304    /// or this method.
305    async fn sendmsg(
306        &mut self,
307        _buffer: &[u8],
308        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
309    ) -> io::Result<usize> {
310        unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
311    }
312
313    /// The dbus daemon on `freebsd` and `dragonfly` currently requires sending the zero byte
314    /// as a separate message with SCM_CREDS, as part of the `EXTERNAL` authentication on unix
315    /// sockets. This method is used by the authentication machinery in zbus to send this
316    /// zero byte. Socket implementations based on unix sockets should implement this method.
317    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
318    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
319        Ok(None)
320    }
321
322    /// Close the socket.
323    ///
324    /// After this call, it is valid for all reading and writing operations to fail.
325    async fn close(&mut self) -> io::Result<()>;
326
327    /// Whether passing file descriptors is supported.
328    ///
329    /// Default implementation returns `false`.
330    fn can_pass_unix_fd(&self) -> bool {
331        false
332    }
333
334    /// The peer credentials.
335    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
336        Ok(ConnectionCredentials::default())
337    }
338}
339
340#[async_trait::async_trait]
341impl ReadHalf for Box<dyn ReadHalf> {
342    fn can_pass_unix_fd(&self) -> bool {
343        (**self).can_pass_unix_fd()
344    }
345
346    async fn receive_message(
347        &mut self,
348        seq: u64,
349        already_received_bytes: &mut Vec<u8>,
350        #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
351    ) -> crate::Result<Message> {
352        (**self)
353            .receive_message(
354                seq,
355                already_received_bytes,
356                #[cfg(unix)]
357                already_received_fds,
358            )
359            .await
360    }
361
362    async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
363        (**self).recvmsg(buf).await
364    }
365
366    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
367        (**self).peer_credentials().await
368    }
369
370    fn auth_mechanism(&self) -> AuthMechanism {
371        (**self).auth_mechanism()
372    }
373}
374
375#[async_trait::async_trait]
376impl WriteHalf for Box<dyn WriteHalf> {
377    async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
378        (**self).send_message(msg).await
379    }
380
381    async fn sendmsg(
382        &mut self,
383        buffer: &[u8],
384        #[cfg(unix)] fds: &[BorrowedFd<'_>],
385    ) -> io::Result<usize> {
386        (**self)
387            .sendmsg(
388                buffer,
389                #[cfg(unix)]
390                fds,
391            )
392            .await
393    }
394
395    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
396    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
397        (**self).send_zero_byte().await
398    }
399
400    async fn close(&mut self) -> io::Result<()> {
401        (**self).close().await
402    }
403
404    fn can_pass_unix_fd(&self) -> bool {
405        (**self).can_pass_unix_fd()
406    }
407
408    async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
409        (**self).peer_credentials().await
410    }
411}
412
413#[cfg(not(feature = "tokio"))]
414impl<T> Socket for Async<T>
415where
416    T: std::fmt::Debug + Send + Sync,
417    Arc<Async<T>>: ReadHalf + WriteHalf,
418{
419    type ReadHalf = Arc<Async<T>>;
420    type WriteHalf = Arc<Async<T>>;
421
422    fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
423        let arc = Arc::new(self);
424
425        Split {
426            read: arc.clone(),
427            write: arc,
428        }
429    }
430}