zbus/connection/socket/
unix.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3use std::io;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, RawFd};
6#[cfg(all(unix, not(feature = "tokio")))]
7use std::os::unix::net::UnixStream;
8#[cfg(not(feature = "tokio"))]
9use std::sync::Arc;
10#[cfg(unix)]
11use std::{
12    future::poll_fn,
13    io::{IoSlice, IoSliceMut},
14    os::fd::OwnedFd,
15    task::Poll,
16};
17#[cfg(all(windows, not(feature = "tokio")))]
18use uds_windows::UnixStream;
19
20#[cfg(unix)]
21use nix::{
22    cmsg_space,
23    sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, UnixAddr},
24};
25
26#[cfg(unix)]
27use crate::utils::FDS_MAX;
28
29#[cfg(all(unix, not(feature = "tokio")))]
30#[async_trait::async_trait]
31impl super::ReadHalf for Arc<Async<UnixStream>> {
32    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
33        poll_fn(|cx| {
34            let (len, fds) = loop {
35                match fd_recvmsg(self.as_raw_fd(), buf) {
36                    Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
37                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_readable(cx)
38                    {
39                        Poll::Pending => return Poll::Pending,
40                        Poll::Ready(res) => res?,
41                    },
42                    v => break v?,
43                }
44            };
45            Poll::Ready(Ok((len, fds)))
46        })
47        .await
48    }
49
50    /// Supports passing file descriptors.
51    fn can_pass_unix_fd(&self) -> bool {
52        true
53    }
54
55    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
56        get_unix_peer_creds(self).await
57    }
58}
59
60#[cfg(all(unix, not(feature = "tokio")))]
61#[async_trait::async_trait]
62impl super::WriteHalf for Arc<Async<UnixStream>> {
63    async fn sendmsg(
64        &mut self,
65        buffer: &[u8],
66        #[cfg(unix)] fds: &[BorrowedFd<'_>],
67    ) -> io::Result<usize> {
68        poll_fn(|cx| loop {
69            match fd_sendmsg(
70                self.as_raw_fd(),
71                buffer,
72                #[cfg(unix)]
73                fds,
74            ) {
75                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
76                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_writable(cx) {
77                    Poll::Pending => return Poll::Pending,
78                    Poll::Ready(res) => res?,
79                },
80                v => return Poll::Ready(v),
81            }
82        })
83        .await
84    }
85
86    async fn close(&mut self) -> io::Result<()> {
87        let stream = self.clone();
88        crate::Task::spawn_blocking(
89            move || stream.get_ref().shutdown(std::net::Shutdown::Both),
90            "close socket",
91        )
92        .await
93    }
94
95    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
96    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
97        send_zero_byte(self).await.map(Some)
98    }
99
100    /// Supports passing file descriptors.
101    fn can_pass_unix_fd(&self) -> bool {
102        true
103    }
104
105    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
106        super::ReadHalf::peer_credentials(self).await
107    }
108}
109
#[cfg(all(unix, feature = "tokio"))]
impl super::Socket for tokio::net::UnixStream {
    type ReadHalf = tokio::net::unix::OwnedReadHalf;
    type WriteHalf = tokio::net::unix::OwnedWriteHalf;

    /// Splits the stream into independently owned read and write halves.
    fn split(self) -> super::Split<Self::ReadHalf, Self::WriteHalf> {
        let (read_half, write_half) = self.into_split();
        super::Split {
            read: read_half,
            write: write_half,
        }
    }
}
121
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::ReadHalf for tokio::net::unix::OwnedReadHalf {
    /// Receives bytes and any accompanying file descriptors from the socket.
    ///
    /// The crate's own `fd_recvmsg` is used instead of tokio's readers because file
    /// descriptors must be received as well. It is run through `try_io`; on `WouldBlock` the
    /// task waits for read readiness and retries.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let stream: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            let attempt = stream.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(stream.as_raw_fd(), buf)
            });
            match attempt {
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    if let Poll::Ready(readiness) = stream.poll_read_ready(cx) {
                        readiness?;
                    } else {
                        return Poll::Pending;
                    }
                }
                other => return Poll::Ready(other),
            }
        })
        .await
    }

    /// This transport can carry file descriptors alongside the data.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the process on the other end of the socket.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
157
#[cfg(all(unix, feature = "tokio"))]
#[async_trait::async_trait]
impl super::WriteHalf for tokio::net::unix::OwnedWriteHalf {
    /// Writes `buffer` to the socket with `fds` attached as ancillary data.
    ///
    /// The crate's own `fd_sendmsg` is used so file descriptors can be sent along with the
    /// bytes. It is run through `try_io`; on `WouldBlock` the task waits for write readiness
    /// and retries.
    async fn sendmsg(
        &mut self,
        buffer: &[u8],
        #[cfg(unix)] fds: &[BorrowedFd<'_>],
    ) -> io::Result<usize> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        poll_fn(|cx| loop {
            let attempt = stream.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(
                    stream.as_raw_fd(),
                    buffer,
                    #[cfg(unix)]
                    fds,
                )
            });
            match attempt {
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    if let Poll::Ready(readiness) = stream.poll_write_ready(cx) {
                        readiness?;
                    } else {
                        return Poll::Pending;
                    }
                }
                other => return Poll::Ready(other),
            }
        })
        .await
    }

    /// Delegates to tokio's `AsyncWriteExt::shutdown` for this write half.
    async fn close(&mut self) -> io::Result<()> {
        tokio::io::AsyncWriteExt::shutdown(self).await
    }

    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        let sent = send_zero_byte(stream).await?;
        Ok(Some(sent))
    }

    /// This transport can carry file descriptors alongside the data.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Looks up the credentials of the process on the other end of the socket.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        let stream: &tokio::net::UnixStream = self.as_ref();
        get_unix_peer_creds(stream).await
    }
}
207
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::ReadHalf for Arc<Async<UnixStream>> {
    /// Reads raw bytes from the socket.
    ///
    /// No file descriptors are received on this platform; only the byte count is returned.
    async fn recvmsg(&mut self, buf: &mut [u8]) -> super::RecvmsgResult {
        let mut reader: &Async<UnixStream> = self.as_ref();
        futures_lite::AsyncReadExt::read(&mut reader, buf).await
    }

    /// Resolves the peer's process ID and the SID of its process token on the blocking pool.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        let stream = Arc::clone(self);
        crate::Task::spawn_blocking(
            move || {
                use crate::win32::{unix_stream_get_peer_pid, ProcessToken};

                let pid = unix_stream_get_peer_pid(stream.get_ref())? as _;
                // A PID of 0 is passed to `ProcessToken::open` as `None`.
                let target = if pid == 0 { None } else { Some(pid as _) };
                let sid = ProcessToken::open(target).and_then(|token| token.sid())?;
                Ok(crate::fdo::ConnectionCredentials::default()
                    .set_process_id(pid)
                    .set_windows_sid(sid))
            },
            "peer credentials",
        )
        .await
    }
}
242
#[cfg(all(windows, not(feature = "tokio")))]
#[async_trait::async_trait]
impl super::WriteHalf for Arc<Async<UnixStream>> {
    /// Writes raw bytes to the socket; file descriptors are not sent on this platform.
    async fn sendmsg(
        &mut self,
        buf: &[u8],
        #[cfg(unix)] _fds: &[BorrowedFd<'_>],
    ) -> io::Result<usize> {
        let mut writer: &Async<UnixStream> = self.as_ref();
        futures_lite::AsyncWriteExt::write(&mut writer, buf).await
    }

    /// Shuts down both directions of the socket on the blocking task pool.
    async fn close(&mut self) -> io::Result<()> {
        // The blocking task needs its own (`'static`) handle to the stream.
        let handle = Arc::clone(self);
        crate::Task::spawn_blocking(
            move || handle.get_ref().shutdown(std::net::Shutdown::Both),
            "close socket",
        )
        .await
    }

    /// Same as the read half: both halves share one underlying socket.
    async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
        super::ReadHalf::peer_credentials(self).await
    }
}
267
268#[cfg(unix)]
269fn fd_recvmsg(fd: RawFd, buffer: &mut [u8]) -> io::Result<(usize, Vec<OwnedFd>)> {
270    let mut iov = [IoSliceMut::new(buffer)];
271    let mut cmsgspace = cmsg_space!([RawFd; FDS_MAX]);
272
273    let msg = recvmsg::<UnixAddr>(fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())?;
274    if msg.bytes == 0 {
275        return Err(io::Error::new(
276            io::ErrorKind::BrokenPipe,
277            "failed to read from socket",
278        ));
279    }
280    let mut fds = vec![];
281    for cmsg in msg.cmsgs()? {
282        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
283        if let ControlMessageOwned::ScmCreds(_) = cmsg {
284            continue;
285        }
286        if let ControlMessageOwned::ScmRights(fd) = cmsg {
287            fds.extend(fd.iter().map(|&f| unsafe { OwnedFd::from_raw_fd(f) }));
288        } else {
289            return Err(io::Error::new(
290                io::ErrorKind::InvalidData,
291                "unexpected CMSG kind",
292            ));
293        }
294    }
295    Ok((msg.bytes, fds))
296}
297
298#[cfg(unix)]
299fn fd_sendmsg(fd: RawFd, buffer: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
300    // FIXME: Remove this conversion once nix supports BorrowedFd here.
301    //
302    // Tracking issue: https://github.com/nix-rust/nix/issues/1750
303    let fds: Vec<_> = fds.iter().map(|f| f.as_raw_fd()).collect();
304    let cmsg = if !fds.is_empty() {
305        vec![ControlMessage::ScmRights(&fds)]
306    } else {
307        vec![]
308    };
309    let iov = [IoSlice::new(buffer)];
310    match sendmsg::<UnixAddr>(fd, &iov, &cmsg, MsgFlags::empty(), None) {
311        // can it really happen?
312        Ok(0) => Err(io::Error::new(
313            io::ErrorKind::WriteZero,
314            "failed to write to buffer",
315        )),
316        Ok(n) => Ok(n),
317        Err(e) => Err(e.into()),
318    }
319}
320
321#[cfg(unix)]
322async fn get_unix_peer_creds(fd: &impl AsRawFd) -> io::Result<crate::fdo::ConnectionCredentials> {
323    let fd = fd.as_raw_fd();
324    // FIXME: Is it likely enough for sending of 1 byte to block, to justify a task (possibly
325    // launching a thread in turn)?
326    crate::Task::spawn_blocking(move || get_unix_peer_creds_blocking(fd), "peer credentials").await
327}
328
329#[cfg(unix)]
330fn get_unix_peer_creds_blocking(fd: RawFd) -> io::Result<crate::fdo::ConnectionCredentials> {
331    // TODO: get this BorrowedFd directly from get_unix_peer_creds(), but this requires a
332    // 'static lifetime due to the Task.
333    let fd = unsafe { BorrowedFd::borrow_raw(fd) };
334    let mut creds = crate::fdo::ConnectionCredentials::default();
335
336    #[cfg(any(target_os = "android", target_os = "linux"))]
337    {
338        use nix::{
339            sys::socket::{getsockopt, sockopt::PeerCredentials},
340            unistd::{getgrouplist, Gid, Uid, User},
341        };
342        use std::ffi::CString;
343        use tracing::debug;
344
345        let (uid, gid, pid) = {
346            let unix_creds = getsockopt(&fd, PeerCredentials)?;
347            (
348                Uid::from_raw(unix_creds.uid()),
349                Gid::from_raw(unix_creds.uid()),
350                unix_creds.pid() as u32,
351            )
352        };
353        creds = creds.set_unix_user_id(uid.as_raw()).set_process_id(pid);
354
355        // the dbus spec requires groups to be either absent or complete (primary + secondary
356        // groups)
357        let mut groups = User::from_uid(uid)
358            .map_err(|e| debug!("User lookup failed: {}", e))
359            .ok()
360            .flatten()
361            .map(|user| CString::new(user.name))
362            .transpose()?
363            .map(|user| getgrouplist(&user, gid))
364            .transpose()
365            .map_err(|e| debug!("Group lookup failed: {}", e))
366            .ok()
367            .flatten()
368            .unwrap_or(Vec::new());
369        // it also requires the groups to be numerically sorted
370        groups.sort_by_key(|gid| gid.as_raw());
371        for group in groups.iter() {
372            creds = creds.add_unix_group_id(group.as_raw());
373        }
374
375        #[cfg(target_os = "linux")]
376        {
377            use nix::{errno::Errno, sys::socket::sockopt::PeerPidfd};
378            use zvariant::OwnedFd;
379
380            match getsockopt(&fd, PeerPidfd) {
381                Err(Errno::ENOPROTOOPT) => (),
382                Ok(pidfd) => creds = creds.set_process_fd(OwnedFd::from(pidfd)),
383                Err(e) => return Err(e.into()),
384            };
385        }
386    }
387
388    #[cfg(any(
389        target_os = "macos",
390        target_os = "ios",
391        target_os = "freebsd",
392        target_os = "dragonfly",
393        target_os = "openbsd",
394        target_os = "netbsd"
395    ))]
396    {
397        let (uid, _gid) = nix::unistd::getpeereid(fd)?;
398        creds = creds.set_unix_user_id(uid.as_raw())
399
400        // FIXME: Handle pid fetching too
401    }
402
403    Ok(creds)
404}
405
/// Sends a single NUL (`\0`) byte on the socket `fd`, accompanied by an `SCM_CREDS` control
/// message. The send runs on the blocking task pool.
///
/// Returns the number of bytes sent.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
async fn send_zero_byte(fd: &impl AsRawFd) -> io::Result<usize> {
    // SAFETY: `fd` is borrowed for this whole (synchronous) statement, so the raw descriptor
    // it exposes is open while we duplicate it.
    //
    // The blocking task gets its own duplicate. If this future is dropped before the task
    // runs, the task never writes to a closed or reused descriptor number.
    let owned = unsafe { BorrowedFd::borrow_raw(fd.as_raw_fd()) }.try_clone_to_owned()?;
    crate::Task::spawn_blocking(
        move || send_zero_byte_blocking(owned.as_raw_fd()),
        "send zero byte",
    )
    .await
}
412
/// Synchronously sends one NUL byte on `fd` with an `SCM_CREDS` control message attached,
/// returning the number of bytes sent.
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte_blocking(fd: RawFd) -> io::Result<usize> {
    let payload = [std::io::IoSlice::new(b"\0")];
    let control = [ControlMessage::ScmCreds];
    let sent = sendmsg::<()>(fd, &payload, &control, MsgFlags::empty(), None)?;
    Ok(sent)
}