zbus/raw/socket.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3#[cfg(not(feature = "tokio"))]
4use futures_core::ready;
5#[cfg(unix)]
6use std::io::{IoSlice, IoSliceMut};
7#[cfg(feature = "tokio")]
8use std::pin::Pin;
9use std::{
10    io,
11    task::{Context, Poll},
12};
13#[cfg(not(feature = "tokio"))]
14use std::{
15    io::{Read, Write},
16    net::TcpStream,
17};
18
19#[cfg(all(windows, not(feature = "tokio")))]
20use uds_windows::UnixStream;
21
22#[cfg(unix)]
23use nix::{
24    cmsg_space,
25    sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, UnixAddr},
26};
27#[cfg(unix)]
28use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
29
30#[cfg(all(unix, not(feature = "tokio")))]
31use std::os::unix::net::UnixStream;
32
33#[cfg(unix)]
34use crate::{utils::FDS_MAX, OwnedFd};
35
36#[cfg(unix)]
37fn fd_recvmsg(fd: RawFd, buffer: &mut [u8]) -> io::Result<(usize, Vec<OwnedFd>)> {
38    let mut iov = [IoSliceMut::new(buffer)];
39    let mut cmsgspace = cmsg_space!([RawFd; FDS_MAX]);
40
41    let msg = recvmsg::<UnixAddr>(fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())?;
42    if msg.bytes == 0 {
43        return Err(io::Error::new(
44            io::ErrorKind::BrokenPipe,
45            "failed to read from socket",
46        ));
47    }
48    let mut fds = vec![];
49    for cmsg in msg.cmsgs() {
50        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
51        if let ControlMessageOwned::ScmCreds(_) = cmsg {
52            continue;
53        }
54        if let ControlMessageOwned::ScmRights(fd) = cmsg {
55            fds.extend(fd.iter().map(|&f| unsafe { OwnedFd::from_raw_fd(f) }));
56        } else {
57            return Err(io::Error::new(
58                io::ErrorKind::InvalidData,
59                "unexpected CMSG kind",
60            ));
61        }
62    }
63    Ok((msg.bytes, fds))
64}
65
66#[cfg(unix)]
67fn fd_sendmsg(fd: RawFd, buffer: &[u8], fds: &[RawFd]) -> io::Result<usize> {
68    let cmsg = if !fds.is_empty() {
69        vec![ControlMessage::ScmRights(fds)]
70    } else {
71        vec![]
72    };
73    let iov = [IoSlice::new(buffer)];
74    match sendmsg::<UnixAddr>(fd, &iov, &cmsg, MsgFlags::empty(), None) {
75        // can it really happen?
76        Ok(0) => Err(io::Error::new(
77            io::ErrorKind::WriteZero,
78            "failed to write to buffer",
79        )),
80        Ok(n) => Ok(n),
81        Err(e) => Err(e.into()),
82    }
83}
84
85#[cfg(unix)]
86fn get_unix_pid(fd: &impl AsRawFd) -> io::Result<Option<u32>> {
87    #[cfg(any(target_os = "android", target_os = "linux"))]
88    {
89        use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
90
91        let fd = fd.as_raw_fd();
92        getsockopt(fd, PeerCredentials)
93            .map(|creds| Some(creds.pid() as _))
94            .map_err(|e| e.into())
95    }
96
97    #[cfg(any(
98        target_os = "macos",
99        target_os = "ios",
100        target_os = "freebsd",
101        target_os = "dragonfly",
102        target_os = "openbsd",
103        target_os = "netbsd"
104    ))]
105    {
106        let _ = fd;
107        // FIXME
108        Ok(None)
109    }
110}
111
112#[cfg(unix)]
113fn get_unix_uid(fd: &impl AsRawFd) -> io::Result<Option<u32>> {
114    let fd = fd.as_raw_fd();
115
116    #[cfg(any(target_os = "android", target_os = "linux"))]
117    {
118        use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
119
120        getsockopt(fd, PeerCredentials)
121            .map(|creds| Some(creds.uid()))
122            .map_err(|e| e.into())
123    }
124
125    #[cfg(any(
126        target_os = "macos",
127        target_os = "ios",
128        target_os = "freebsd",
129        target_os = "dragonfly",
130        target_os = "openbsd",
131        target_os = "netbsd"
132    ))]
133    {
134        nix::unistd::getpeereid(fd)
135            .map(|(uid, _)| Some(uid.into()))
136            .map_err(|e| e.into())
137    }
138}
139
/// Sends the single NUL byte of the `EXTERNAL` authentication handshake as its own message,
/// accompanied by an `SCM_CREDS` control message.
///
/// Returns the number of bytes written; `sendmsg` failures are converted into [`io::Error`].
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte(fd: &impl AsRawFd) -> io::Result<usize> {
    let zero = [std::io::IoSlice::new(b"\0")];
    let creds = [ControlMessage::ScmCreds];
    let sent = sendmsg::<()>(fd.as_raw_fd(), &zero, &creds, MsgFlags::empty(), None);
    sent.map_err(io::Error::from)
}
154
155#[cfg(unix)]
156type PollRecvmsg = Poll<io::Result<(usize, Vec<OwnedFd>)>>;
157
158#[cfg(not(unix))]
159type PollRecvmsg = Poll<io::Result<usize>>;
160
161/// Trait representing some transport layer over which the DBus protocol can be used
162///
163/// The crate provides implementations for `async_io` and `tokio`'s `UnixStream` wrappers if you
164/// enable the corresponding crate features (`async_io` is enabled by default).
165///
166/// You can implement it manually to integrate with other runtimes or other dbus transports.  Feel
167/// free to submit pull requests to add support for more runtimes to zbus itself so rust's orphan
168/// rules don't force the use of a wrapper struct (and to avoid duplicating the work across many
169/// projects).
170pub trait Socket: std::fmt::Debug + Send + Sync {
171    /// Supports passing file descriptors.
172    fn can_pass_unix_fd(&self) -> bool {
173        true
174    }
175
176    /// Attempt to receive a message from the socket.
177    ///
178    /// On success, returns the number of bytes read as well as a `Vec` containing
179    /// any associated file descriptors.
180    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg;
181
182    /// Attempt to send a message on the socket
183    ///
184    /// On success, return the number of bytes written. There may be a partial write, in
185    /// which case the caller is responsible of sending the remaining data by calling this
186    /// method again until everything is written or it returns an error of kind `WouldBlock`.
187    ///
188    /// If at least one byte has been written, then all the provided file descriptors will
189    /// have been sent as well, and should not be provided again in subsequent calls.
190    ///
191    /// If the underlying transport does not support transmitting file descriptors, this
192    /// will return `Err(ErrorKind::InvalidInput)`.
193    fn poll_sendmsg(
194        &mut self,
195        cx: &mut Context<'_>,
196        buffer: &[u8],
197        #[cfg(unix)] fds: &[RawFd],
198    ) -> Poll<io::Result<usize>>;
199
200    /// Close the socket.
201    ///
202    /// After this call, it is valid for all reading and writing operations to fail.
203    fn close(&self) -> io::Result<()>;
204
205    /// Return the peer PID.
206    fn peer_pid(&self) -> io::Result<Option<u32>> {
207        Ok(None)
208    }
209
210    /// Return the peer process SID, if any.
211    #[cfg(windows)]
212    fn peer_sid(&self) -> Option<String> {
213        None
214    }
215
216    /// Return the User ID, if any.
217    #[cfg(unix)]
218    fn uid(&self) -> io::Result<Option<u32>> {
219        Ok(None)
220    }
221
222    /// The dbus daemon on `freebsd` and `dragonfly` currently requires sending the zero byte
223    /// as a separate message with SCM_CREDS, as part of the `EXTERNAL` authentication on unix
224    /// sockets. This method is used by the authentication machinery in zbus to send this
225    /// zero byte. Socket implementations based on unix sockets should implement this method.
226    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
227    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
228        Ok(None)
229    }
230}
231
232impl Socket for Box<dyn Socket> {
233    fn can_pass_unix_fd(&self) -> bool {
234        (**self).can_pass_unix_fd()
235    }
236
237    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
238        (**self).poll_recvmsg(cx, buf)
239    }
240
241    fn poll_sendmsg(
242        &mut self,
243        cx: &mut Context<'_>,
244        buffer: &[u8],
245        #[cfg(unix)] fds: &[RawFd],
246    ) -> Poll<io::Result<usize>> {
247        (**self).poll_sendmsg(
248            cx,
249            buffer,
250            #[cfg(unix)]
251            fds,
252        )
253    }
254
255    fn close(&self) -> io::Result<()> {
256        (**self).close()
257    }
258
259    fn peer_pid(&self) -> io::Result<Option<u32>> {
260        (**self).peer_pid()
261    }
262
263    #[cfg(windows)]
264    fn peer_sid(&self) -> Option<String> {
265        (&**self).peer_sid()
266    }
267
268    #[cfg(unix)]
269    fn uid(&self) -> io::Result<Option<u32>> {
270        (**self).uid()
271    }
272
273    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
274    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
275        (**self).send_zero_byte()
276    }
277}
278
279#[cfg(all(unix, not(feature = "tokio")))]
280impl Socket for Async<UnixStream> {
281    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
282        let (len, fds) = loop {
283            match fd_recvmsg(self.as_raw_fd(), buf) {
284                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
285                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_readable(cx) {
286                    Poll::Pending => return Poll::Pending,
287                    Poll::Ready(res) => res?,
288                },
289                v => break v?,
290            }
291        };
292        Poll::Ready(Ok((len, fds)))
293    }
294
295    fn poll_sendmsg(
296        &mut self,
297        cx: &mut Context<'_>,
298        buffer: &[u8],
299        #[cfg(unix)] fds: &[RawFd],
300    ) -> Poll<io::Result<usize>> {
301        loop {
302            match fd_sendmsg(
303                self.as_raw_fd(),
304                buffer,
305                #[cfg(unix)]
306                fds,
307            ) {
308                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
309                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_writable(cx) {
310                    Poll::Pending => return Poll::Pending,
311                    Poll::Ready(res) => res?,
312                },
313                v => return Poll::Ready(v),
314            }
315        }
316    }
317
318    fn close(&self) -> io::Result<()> {
319        self.get_ref().shutdown(std::net::Shutdown::Both)
320    }
321
322    fn peer_pid(&self) -> io::Result<Option<u32>> {
323        get_unix_pid(self)
324    }
325
326    #[cfg(unix)]
327    fn uid(&self) -> io::Result<Option<u32>> {
328        get_unix_uid(self)
329    }
330
331    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
332    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
333        send_zero_byte(self).map(Some)
334    }
335}
336
/// Tokio-backed unix domain socket, with support for passing file descriptors.
#[cfg(all(unix, feature = "tokio"))]
impl Socket for tokio::net::UnixStream {
    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            let attempt = self.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(self.as_raw_fd(), buf)
            });
            let err = match attempt {
                Ok(received) => return Poll::Ready(Ok(received)),
                Err(err) => err,
            };
            match err.kind() {
                // A signal interrupted the call: try again straight away.
                io::ErrorKind::Interrupted => {}
                // Nothing to read yet: wait for read readiness, then retry.
                io::ErrorKind::WouldBlock => match self.poll_read_ready(cx) {
                    Poll::Ready(readiness) => readiness?,
                    Poll::Pending => return Poll::Pending,
                },
                _ => return Poll::Ready(Err(err)),
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &[u8],
        fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        loop {
            let attempt = self.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(self.as_raw_fd(), buffer, fds)
            });
            let err = match attempt {
                Ok(written) => return Poll::Ready(Ok(written)),
                Err(err) => err,
            };
            match err.kind() {
                // A signal interrupted the call: try again straight away.
                io::ErrorKind::Interrupted => {}
                // The send buffer is full: wait for write readiness, then retry.
                io::ErrorKind::WouldBlock => match self.poll_write_ready(cx) {
                    Poll::Ready(readiness) => readiness?,
                    Poll::Pending => return Poll::Pending,
                },
                _ => return Poll::Ready(Err(err)),
            }
        }
    }

    fn close(&self) -> io::Result<()> {
        // FIXME: This should call `tokio::net::UnixStream::poll_shutdown` but this method is not
        // async-friendly. At the next API break, we should fix this.
        Ok(())
    }

    fn peer_pid(&self) -> io::Result<Option<u32>> {
        get_unix_pid(self)
    }

    fn uid(&self) -> io::Result<Option<u32>> {
        get_unix_uid(self)
    }

    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
        send_zero_byte(self).map(Some)
    }
}
401
/// `async-io` backed unix domain socket on Windows (`uds_windows::UnixStream`).
///
/// File descriptors cannot be passed over it. The peer PID is looked up through the
/// `win32` helpers, and the peer SID is derived from that PID.
#[cfg(all(windows, not(feature = "tokio")))]
impl Socket for Async<UnixStream> {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            match (*self).get_mut().read(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted read transferred nothing and must simply be retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                res => return Poll::Ready(res),
            }
            ready!(self.poll_readable(cx))?;
        }
    }

    fn poll_sendmsg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        loop {
            match (*self).get_mut().write(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted write transferred nothing and must simply be retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                res => return Poll::Ready(res),
            }
            ready!(self.poll_writable(cx))?;
        }
    }

    fn close(&self) -> io::Result<()> {
        self.get_ref().shutdown(std::net::Shutdown::Both)
    }

    fn peer_sid(&self) -> Option<String> {
        use crate::win32::ProcessToken;

        if let Ok(Some(pid)) = self.peer_pid() {
            // A PID of 0 is passed as `None` to `ProcessToken::open`.
            if let Ok(process_token) =
                ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
            {
                return process_token.sid().ok();
            }
        }

        None
    }

    fn peer_pid(&self) -> io::Result<Option<u32>> {
        use crate::win32::unix_stream_get_peer_pid;

        Ok(Some(unix_stream_get_peer_pid(&self.get_ref())? as _))
    }
}
473
474#[cfg(not(feature = "tokio"))]
475impl Socket for Async<TcpStream> {
476    fn can_pass_unix_fd(&self) -> bool {
477        false
478    }
479
480    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
481        #[cfg(unix)]
482        let fds = vec![];
483
484        loop {
485            match (*self).get_mut().read(buf) {
486                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
487                Err(e) => return Poll::Ready(Err(e)),
488                Ok(len) => {
489                    #[cfg(unix)]
490                    let ret = (len, fds);
491                    #[cfg(not(unix))]
492                    let ret = len;
493                    return Poll::Ready(Ok(ret));
494                }
495            }
496            ready!(self.poll_readable(cx))?;
497        }
498    }
499
500    fn poll_sendmsg(
501        &mut self,
502        cx: &mut Context<'_>,
503        buf: &[u8],
504        #[cfg(unix)] fds: &[RawFd],
505    ) -> Poll<io::Result<usize>> {
506        #[cfg(unix)]
507        if !fds.is_empty() {
508            return Poll::Ready(Err(io::Error::new(
509                io::ErrorKind::InvalidInput,
510                "fds cannot be sent with a tcp stream",
511            )));
512        }
513
514        loop {
515            match (*self).get_mut().write(buf) {
516                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
517                res => return Poll::Ready(res),
518            }
519            ready!(self.poll_writable(cx))?;
520        }
521    }
522
523    fn close(&self) -> io::Result<()> {
524        self.get_ref().shutdown(std::net::Shutdown::Both)
525    }
526
527    #[cfg(windows)]
528    fn peer_sid(&self) -> Option<String> {
529        use crate::win32::{tcp_stream_get_peer_pid, ProcessToken};
530
531        if let Ok(pid) = tcp_stream_get_peer_pid(&self.get_ref()) {
532            if let Ok(process_token) = ProcessToken::open(if pid != 0 { Some(pid) } else { None }) {
533                return process_token.sid().ok();
534            }
535        }
536
537        None
538    }
539}
540
/// Tokio-backed TCP transport. TCP cannot carry file descriptors, so none are ever received
/// and sending any is rejected with `InvalidInput`.
#[cfg(feature = "tokio")]
impl Socket for tokio::net::TcpStream {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        use tokio::io::{AsyncRead, ReadBuf};

        let mut read_buf = ReadBuf::new(buf);
        match Pin::new(self).poll_read(cx, &mut read_buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                let received = read_buf.filled().len();
                #[cfg(unix)]
                let received = (received, vec![]);
                Poll::Ready(Ok(received))
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        use tokio::io::AsyncWrite;

        #[cfg(unix)]
        {
            if !fds.is_empty() {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "fds cannot be sent with a tcp stream",
                )));
            }
        }

        AsyncWrite::poll_write(Pin::new(self), cx, buf)
    }

    fn close(&self) -> io::Result<()> {
        // FIXME: This should call `tokio::net::TcpStream::poll_shutdown` but this method is not
        // async-friendly. At the next API break, we should fix this.
        Ok(())
    }

    #[cfg(windows)]
    fn peer_sid(&self) -> Option<String> {
        use crate::win32::{socket_addr_get_pid, ProcessToken};

        // Any failure along the way means the SID is simply unknown.
        let peer_addr = self.peer_addr().ok()?;
        let pid = socket_addr_get_pid(&peer_addr).ok()?;
        let process_token = ProcessToken::open(if pid != 0 { Some(pid) } else { None }).ok()?;
        process_token.sid().ok()
    }
}
605
/// `async-io` backed vsock transport. Vsock cannot carry file descriptors, so none are ever
/// received and sending any is rejected with `InvalidInput`.
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
impl Socket for Async<vsock::VsockStream> {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            match (*self).get_mut().read(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted read transferred nothing and must simply be retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Poll::Ready(Err(e)),
                Ok(len) => {
                    #[cfg(unix)]
                    let ret = (len, vec![]);
                    #[cfg(not(unix))]
                    let ret = len;
                    return Poll::Ready(Ok(ret));
                }
            }
            ready!(self.poll_readable(cx))?;
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        #[cfg(unix)]
        if !fds.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "fds cannot be sent with a vsock stream",
            )));
        }

        loop {
            match (*self).get_mut().write(buf) {
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted write transferred nothing and must simply be retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                res => return Poll::Ready(res),
            }
            ready!(self.poll_writable(cx))?;
        }
    }

    fn close(&self) -> io::Result<()> {
        self.get_ref().shutdown(std::net::Shutdown::Both)
    }
}
659
/// Tokio-backed vsock transport. Vsock cannot carry file descriptors, so none are ever
/// received and sending any is rejected with `InvalidInput`.
///
/// `Pin` is spelled out as `std::pin::Pin` because the file-level `Pin` import is gated on
/// the `tokio` feature, not on `tokio-vsock`.
#[cfg(feature = "tokio-vsock")]
impl Socket for tokio_vsock::VsockStream {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        use tokio::io::{AsyncRead, ReadBuf};

        let mut read_buf = ReadBuf::new(buf);
        std::pin::Pin::new(self)
            .poll_read(cx, &mut read_buf)
            .map(|res| {
                res.map(|_| {
                    let ret = read_buf.filled().len();
                    #[cfg(unix)]
                    let ret = (ret, vec![]);

                    ret
                })
            })
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        use tokio::io::AsyncWrite;

        #[cfg(unix)]
        if !fds.is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "fds cannot be sent with a vsock stream",
            )));
        }

        std::pin::Pin::new(self).poll_write(cx, buf)
    }

    fn close(&self) -> io::Result<()> {
        self.shutdown(std::net::Shutdown::Both)
    }
}