x11rb/rust_connection/
stream.rs

1use rustix::fd::{AsFd, BorrowedFd};
2use std::io::{IoSlice, Result};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
6#[cfg(unix)]
7use std::os::unix::net::UnixStream;
8#[cfg(windows)]
9use std::os::windows::io::{
10    AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
11};
12
13use crate::utils::RawFdContainer;
14use x11rb_protocol::parse_display::ConnectAddress;
15use x11rb_protocol::xauth::Family;
16
/// The kind of operation that one wants to poll for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollMode {
    /// Check if the stream is readable, i.e. there is pending data to be read.
    Readable,

    /// Check if the stream is writable, i.e. some data could be successfully written to it.
    Writable,

    /// Check for both readability and writability.
    ReadAndWritable,
}
29
30impl PollMode {
31    /// Does this poll mode include readability?
32    pub fn readable(self) -> bool {
33        match self {
34            PollMode::Readable | PollMode::ReadAndWritable => true,
35            PollMode::Writable => false,
36        }
37    }
38
39    /// Does this poll mode include writability?
40    pub fn writable(self) -> bool {
41        match self {
42            PollMode::Writable | PollMode::ReadAndWritable => true,
43            PollMode::Readable => false,
44        }
45    }
46}
47
48/// A trait used to implement the raw communication with the X11 server.
49///
50/// None of the functions of this trait shall return [`std::io::ErrorKind::Interrupted`].
51/// If a system call fails with this error, the implementation should try again.
52pub trait Stream {
53    /// Waits for level-triggered read and/or write events on the stream.
54    ///
55    /// This function does not return what caused it to complete the poll.
56    /// Instead, callers should try to read or write and check for
57    /// [`std::io::ErrorKind::WouldBlock`].
58    ///
59    /// This function is allowed to spuriously return even if the stream
60    /// is neither readable nor writable. However, it shall not do it
61    /// continuously, which would cause a 100% CPU usage.
62    ///
63    /// # Multithreading
64    ///
65    /// If `Self` is `Send + Sync` and `poll` is used concurrently from more than
66    /// one thread, all threads should wake when the stream becomes readable (when
67    /// `read` is `true`) or writable (when `write` is `true`).
68    fn poll(&self, mode: PollMode) -> Result<()>;
69
70    /// Read some bytes and FDs from this reader without blocking, returning how many bytes
71    /// were read.
72    ///
73    /// This function works like [`std::io::Read::read`], but also supports the reception of file
74    /// descriptors. Any received file descriptors are appended to the given `fd_storage`.
75    /// Whereas implementation of [`std::io::Read::read`] are allowed to block or not to block,
76    /// this method shall never block and return `ErrorKind::WouldBlock` if needed.
77    ///
78    /// This function does not guarantee that all file descriptors were sent together with the data
79    /// with which they are received. However, file descriptors may not be received later than the
80    /// data that was sent at the same time. Instead, file descriptors may only be received
81    /// earlier.
82    ///
83    /// # Multithreading
84    ///
85    /// If `Self` is `Send + Sync` and `read` is used concurrently from more than one thread:
86    ///
87    /// * Both the data and the file descriptors shall be read in order, but possibly
88    ///   interleaved across threads.
89    /// * Neither the data nor the file descriptors shall be duplicated.
90    /// * The returned value shall always be the actual number of bytes read into `buf`.
91    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;
92
93    /// Write a buffer and some FDs into this writer without blocking, returning how many
94    /// bytes were written.
95    ///
96    /// This function works like [`std::io::Write::write`], but also supports sending file
97    /// descriptors. The `fds` argument contains the file descriptors to send. The order of file
98    /// descriptors is maintained. Whereas implementation of [`std::io::Write::write`] are
99    /// allowed to block or not to block, this function must never block and return
100    /// `ErrorKind::WouldBlock` if needed.
101    ///
102    /// This function does not guarantee that all file descriptors are sent together with the data.
103    /// Any file descriptors that were sent are removed from the beginning of the given `Vec`.
104    ///
105    /// There is no guarantee that the given file descriptors are received together with the given
106    /// data. File descriptors might be received earlier than their corresponding data. It is not
107    /// allowed for file descriptors to be received later than the bytes that were sent at the same
108    /// time.
109    ///
110    /// # Multithreading
111    ///
112    /// If `Self` is `Send + Sync` and `write` is used concurrently from more than one thread:
113    ///
114    /// * Both the data and the file descriptors shall be written in order, but possibly
115    ///   interleaved across threads.
116    /// * Neither the data nor the file descriptors shall be duplicated.
117    /// * The returned value shall always be the actual number of bytes written from `buf`.
118    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;
119
120    /// Like `write`, except that it writes from a slice of buffers. Like `write`, this
121    /// method must never block.
122    ///
123    /// This method must behave as a call to `write` with the buffers concatenated would.
124    ///
125    /// The default implementation calls `write` with the first nonempty buffer provided.
126    ///
127    /// # Multithreading
128    ///
129    /// Same as `write`.
130    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
131        for buf in bufs {
132            if !buf.is_empty() {
133                return self.write(buf, fds);
134            }
135        }
136        Ok(0)
137    }
138}
139
140/// A wrapper around a `TcpStream` or `UnixStream`.
141///
142/// Use by default in `RustConnection` as stream.
143#[derive(Debug)]
144pub struct DefaultStream {
145    inner: DefaultStreamInner,
146}
147
148#[cfg(unix)]
149type DefaultStreamInner = RawFdContainer;
150
151#[cfg(not(unix))]
152type DefaultStreamInner = TcpStream;
153
154/// The address of a peer in a format suitable for xauth.
155///
156/// These values can be directly given to [`x11rb_protocol::xauth::get_auth`].
157type PeerAddr = (Family, Vec<u8>);
158
159impl DefaultStream {
160    /// Try to connect to the X11 server described by the given arguments.
161    pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
162        match addr {
163            ConnectAddress::Hostname(host, port) => {
164                // connect over TCP
165                let stream = TcpStream::connect((*host, *port))?;
166                Self::from_tcp_stream(stream)
167            }
168            #[cfg(unix)]
169            ConnectAddress::Socket(path) => {
170                // Try abstract unix socket first. If that fails, fall back to normal unix socket
171                #[cfg(any(target_os = "linux", target_os = "android"))]
172                if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) {
173                    // TODO: Does it make sense to add a constructor similar to from_unix_stream()?
174                    // If this is done: Move the set_nonblocking() from
175                    // connect_abstract_unix_stream() to that new function.
176                    let stream = DefaultStream { inner: stream };
177                    return Ok((stream, peer_addr::local()));
178                }
179
180                // connect over Unix domain socket
181                let stream = UnixStream::connect(path)?;
182                Self::from_unix_stream(stream)
183            }
184            #[cfg(not(unix))]
185            ConnectAddress::Socket(_) => {
186                // Unix domain sockets are not supported on Windows
187                Err(std::io::Error::new(
188                    std::io::ErrorKind::Other,
189                    "Unix domain sockets are not supported on Windows",
190                ))
191            }
192            _ => Err(std::io::Error::new(
193                std::io::ErrorKind::Other,
194                "The given address family is not implemented",
195            )),
196        }
197    }
198
199    /// Creates a new `Stream` from an already connected `TcpStream`.
200    ///
201    /// The stream will be set in non-blocking mode.
202    ///
203    /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
204    pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
205        let peer_addr = peer_addr::tcp(&stream.peer_addr()?);
206        stream.set_nonblocking(true)?;
207        let result = Self {
208            inner: stream.into(),
209        };
210        Ok((result, peer_addr))
211    }
212
213    /// Creates a new `Stream` from an already connected `UnixStream`.
214    ///
215    /// The stream will be set in non-blocking mode.
216    ///
217    /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
218    #[cfg(unix)]
219    pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
220        stream.set_nonblocking(true)?;
221        let result = Self {
222            inner: stream.into(),
223        };
224        Ok((result, peer_addr::local()))
225    }
226
227    fn as_fd(&self) -> BorrowedFd<'_> {
228        self.inner.as_fd()
229    }
230}
231
232#[cfg(unix)]
233impl AsRawFd for DefaultStream {
234    fn as_raw_fd(&self) -> RawFd {
235        self.inner.as_raw_fd()
236    }
237}
238
239#[cfg(unix)]
240impl AsFd for DefaultStream {
241    fn as_fd(&self) -> BorrowedFd<'_> {
242        self.inner.as_fd()
243    }
244}
245
246#[cfg(unix)]
247impl IntoRawFd for DefaultStream {
248    fn into_raw_fd(self) -> RawFd {
249        self.inner.into_raw_fd()
250    }
251}
252
253#[cfg(unix)]
254impl From<DefaultStream> for OwnedFd {
255    fn from(stream: DefaultStream) -> Self {
256        stream.inner
257    }
258}
259
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    /// Exposes the raw handle of the wrapped socket; ownership stays with `self`.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}
266
#[cfg(windows)]
impl AsSocket for DefaultStream {
    /// Borrows the wrapped socket for the lifetime of `self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}
273
#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    /// Gives up ownership of the wrapped socket; the caller becomes responsible for closing it.
    fn into_raw_socket(self) -> RawSocket {
        IntoRawSocket::into_raw_socket(self.inner)
    }
}
280
#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    /// Unwraps the stream into an owned socket handle.
    fn from(stream: DefaultStream) -> Self {
        OwnedSocket::from(stream.inner)
    }
}
287
288#[cfg(unix)]
289fn do_write(
290    stream: &DefaultStream,
291    bufs: &[IoSlice<'_>],
292    fds: &mut Vec<RawFdContainer>,
293) -> Result<usize> {
294    use rustix::io::Errno;
295    use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
296    use std::mem::MaybeUninit;
297
298    fn sendmsg_wrapper(
299        fd: BorrowedFd<'_>,
300        iov: &[IoSlice<'_>],
301        cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
302        flags: SendFlags,
303    ) -> Result<usize> {
304        loop {
305            match sendmsg(fd, iov, cmsgs, flags) {
306                Ok(n) => return Ok(n),
307                // try again
308                Err(Errno::INTR) => {}
309                Err(e) => return Err(e.into()),
310            }
311        }
312    }
313
314    let fd = stream.as_fd();
315
316    let res = if !fds.is_empty() {
317        let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
318        let rights = SendAncillaryMessage::ScmRights(&fds);
319
320        let mut cmsg_space = vec![MaybeUninit::uninit(); rights.size()];
321        let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
322        assert!(cmsg_buffer.push(rights));
323
324        sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())?
325    } else {
326        sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())?
327    };
328
329    // We successfully sent all FDs
330    fds.clear();
331
332    Ok(res)
333}
334
335impl Stream for DefaultStream {
336    fn poll(&self, mode: PollMode) -> Result<()> {
337        use rustix::event::{poll, PollFd, PollFlags};
338        use rustix::io::Errno;
339
340        let mut poll_flags = PollFlags::empty();
341        if mode.readable() {
342            poll_flags |= PollFlags::IN;
343        }
344        if mode.writable() {
345            poll_flags |= PollFlags::OUT;
346        }
347        let fd = self.as_fd();
348        let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
349        loop {
350            match poll(&mut poll_fds, None) {
351                Ok(_) => break,
352                Err(Errno::INTR) => {}
353                Err(e) => return Err(e.into()),
354            }
355        }
356        // Let the errors (POLLERR) be handled when trying to read or write.
357        Ok(())
358    }
359
360    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
361        #[cfg(unix)]
362        {
363            use rustix::io::Errno;
364            use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
365            use std::io::IoSliceMut;
366            use std::mem::MaybeUninit;
367
368            // 1024 bytes on the stack should be enough for more file descriptors than the X server will ever
369            // send, as well as the header for the ancillary data. If you can find a case where this can
370            // overflow with an actual production X11 server, I'll buy you a steak dinner.
371            let mut cmsg = [MaybeUninit::uninit(); 1024];
372            let mut iov = [IoSliceMut::new(buf)];
373            let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);
374
375            let fd = self.as_fd();
376            let msg = loop {
377                match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
378                    Ok(msg) => break msg,
379                    // try again
380                    Err(Errno::INTR) => {}
381                    Err(e) => return Err(e.into()),
382                }
383            };
384
385            let fds_received = cmsg_buffer
386                .drain()
387                .filter_map(|cmsg| match cmsg {
388                    RecvAncillaryMessage::ScmRights(r) => Some(r),
389                    _ => None,
390                })
391                .flatten();
392
393            let mut cloexec_error = Ok(());
394            fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
395            cloexec_error?;
396
397            Ok(msg.bytes)
398        }
399        #[cfg(not(unix))]
400        {
401            use std::io::Read;
402            // No FDs are read, so nothing needs to be done with fd_storage
403            let _ = fd_storage;
404            loop {
405                // Use `impl Read for &TcpStream` to avoid needing a mutable `TcpStream`.
406                match (&mut &self.inner).read(buf) {
407                    Ok(n) => return Ok(n),
408                    // try again
409                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
410                    Err(e) => return Err(e),
411                }
412            }
413        }
414    }
415
416    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
417        #[cfg(unix)]
418        {
419            do_write(self, &[IoSlice::new(buf)], fds)
420        }
421        #[cfg(not(unix))]
422        {
423            use std::io::{Error, ErrorKind, Write};
424            if !fds.is_empty() {
425                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
426            }
427            loop {
428                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
429                match (&mut &self.inner).write(buf) {
430                    Ok(n) => return Ok(n),
431                    // try again
432                    Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
433                    Err(e) => return Err(e),
434                }
435            }
436        }
437    }
438
439    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
440        #[cfg(unix)]
441        {
442            do_write(self, bufs, fds)
443        }
444        #[cfg(not(unix))]
445        {
446            use std::io::{Error, ErrorKind, Write};
447            if !fds.is_empty() {
448                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
449            }
450            loop {
451                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
452                match (&mut &self.inner).write_vectored(bufs) {
453                    Ok(n) => return Ok(n),
454                    // try again
455                    Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
456                    Err(e) => return Err(e),
457                }
458            }
459        }
460    }
461}
462
463#[cfg(any(target_os = "linux", target_os = "android"))]
464fn connect_abstract_unix_stream(
465    path: &[u8],
466) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
467    use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
468    use rustix::net::{
469        connect, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
470    };
471
472    let socket = socket_with(
473        AddressFamily::UNIX,
474        SocketType::STREAM,
475        SocketFlags::CLOEXEC,
476        None,
477    )?;
478
479    connect(&socket, &SocketAddrUnix::new_abstract_name(path)?)?;
480
481    // Make the FD non-blocking
482    fcntl_setfl(&socket, fcntl_getfl(&socket)? | OFlags::NONBLOCK)?;
483
484    Ok(socket)
485}
486
487/// Helper code to make sure that received FDs are marked as CLOEXEC
488#[cfg(any(
489    target_os = "android",
490    target_os = "dragonfly",
491    target_os = "freebsd",
492    target_os = "linux",
493    target_os = "netbsd",
494    target_os = "openbsd"
495))]
496mod recvmsg {
497    use super::RawFdContainer;
498    use rustix::net::RecvFlags;
499
500    pub(crate) fn flags() -> RecvFlags {
501        RecvFlags::CMSG_CLOEXEC
502    }
503
504    pub(crate) fn after_recvmsg<'a>(
505        fds: impl Iterator<Item = RawFdContainer> + 'a,
506        _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
507    ) -> impl Iterator<Item = RawFdContainer> + 'a {
508        fds
509    }
510}
511
/// Helper code to make sure that received FDs are marked as CLOEXEC.
///
/// These Unix platforms get no close-on-exec flag from `recvmsg`, so it is set on each
/// received descriptor afterwards with `fcntl`.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// Flags for `recvmsg`: none; close-on-exec is applied in `after_recvmsg`.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Adds `FD_CLOEXEC` to every descriptor as it is yielded.
    ///
    /// Each descriptor is yielded even if updating its flags failed; the most recent failure
    /// is stored into `cloexec_error` for the caller to check once iteration is done.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let outcome =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = outcome {
                *cloexec_error = Err(errno);
            }
        })
    }
}
547
548mod peer_addr {
549    use super::{Family, PeerAddr};
550    use std::net::{Ipv4Addr, SocketAddr};
551
552    // Get xauth information representing a local connection
553    pub(super) fn local() -> PeerAddr {
554        let hostname = gethostname::gethostname()
555            .to_str()
556            .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
557        (Family::LOCAL, hostname)
558    }
559
560    // Get xauth information representing a TCP connection to the given address
561    pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
562        let ip = match addr {
563            SocketAddr::V4(addr) => *addr.ip(),
564            SocketAddr::V6(addr) => {
565                let ip = addr.ip();
566                if ip.is_loopback() {
567                    // This is a local connection.
568                    // Use LOCALHOST to cause a fall-through in the code below.
569                    Ipv4Addr::LOCALHOST
570                } else if let Some(ip) = ip.to_ipv4() {
571                    // Let the ipv4 code below handle this
572                    ip
573                } else {
574                    // Okay, this is really a v6 address
575                    return (Family::INTERNET6, ip.octets().to_vec());
576                }
577            }
578        };
579
580        // Handle the v4 address
581        if ip.is_loopback() {
582            local()
583        } else {
584            (Family::INTERNET, ip.octets().to_vec())
585        }
586    }
587}