1use rustix::fd::{AsFd, BorrowedFd};
2use std::io::{IoSlice, Result};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
6#[cfg(unix)]
7use std::os::unix::net::UnixStream;
8#[cfg(windows)]
9use std::os::windows::io::{
10 AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
11};
12
13use crate::utils::RawFdContainer;
14use x11rb_protocol::parse_display::ConnectAddress;
15use x11rb_protocol::xauth::Family;
16
/// Which readiness events to wait for when calling [`Stream::poll`].
///
/// Besides `Debug`/`Clone`/`Copy`, the enum also derives `PartialEq`, `Eq` and `Hash`
/// so callers can compare modes or store them in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollMode {
    /// Wait until the stream is readable.
    Readable,

    /// Wait until the stream is writable.
    Writable,

    /// Wait until the stream is readable or writable (whichever happens first).
    ReadAndWritable,
}
29
30impl PollMode {
31 pub fn readable(self) -> bool {
33 match self {
34 PollMode::Readable | PollMode::ReadAndWritable => true,
35 PollMode::Writable => false,
36 }
37 }
38
39 pub fn writable(self) -> bool {
41 match self {
42 PollMode::Writable | PollMode::ReadAndWritable => true,
43 PollMode::Readable => false,
44 }
45 }
46}
47
48pub trait Stream {
53 fn poll(&self, mode: PollMode) -> Result<()>;
69
70 fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;
92
93 fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;
119
120 fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
131 for buf in bufs {
132 if !buf.is_empty() {
133 return self.write(buf, fds);
134 }
135 }
136 Ok(0)
137 }
138}
139
140#[derive(Debug)]
144pub struct DefaultStream {
145 inner: DefaultStreamInner,
146}
147
148#[cfg(unix)]
149type DefaultStreamInner = RawFdContainer;
150
151#[cfg(not(unix))]
152type DefaultStreamInner = TcpStream;
153
154type PeerAddr = (Family, Vec<u8>);
158
159impl DefaultStream {
160 pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
162 match addr {
163 ConnectAddress::Hostname(host, port) => {
164 let stream = TcpStream::connect((*host, *port))?;
166 Self::from_tcp_stream(stream)
167 }
168 #[cfg(unix)]
169 ConnectAddress::Socket(path) => {
170 #[cfg(any(target_os = "linux", target_os = "android"))]
172 if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) {
173 let stream = DefaultStream { inner: stream };
177 return Ok((stream, peer_addr::local()));
178 }
179
180 let stream = UnixStream::connect(path)?;
182 Self::from_unix_stream(stream)
183 }
184 #[cfg(not(unix))]
185 ConnectAddress::Socket(_) => {
186 Err(std::io::Error::new(
188 std::io::ErrorKind::Other,
189 "Unix domain sockets are not supported on Windows",
190 ))
191 }
192 _ => Err(std::io::Error::new(
193 std::io::ErrorKind::Other,
194 "The given address family is not implemented",
195 )),
196 }
197 }
198
199 pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
205 let peer_addr = peer_addr::tcp(&stream.peer_addr()?);
206 stream.set_nonblocking(true)?;
207 let result = Self {
208 inner: stream.into(),
209 };
210 Ok((result, peer_addr))
211 }
212
213 #[cfg(unix)]
219 pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
220 stream.set_nonblocking(true)?;
221 let result = Self {
222 inner: stream.into(),
223 };
224 Ok((result, peer_addr::local()))
225 }
226
227 fn as_fd(&self) -> BorrowedFd<'_> {
228 self.inner.as_fd()
229 }
230}
231
232#[cfg(unix)]
233impl AsRawFd for DefaultStream {
234 fn as_raw_fd(&self) -> RawFd {
235 self.inner.as_raw_fd()
236 }
237}
238
239#[cfg(unix)]
240impl AsFd for DefaultStream {
241 fn as_fd(&self) -> BorrowedFd<'_> {
242 self.inner.as_fd()
243 }
244}
245
246#[cfg(unix)]
247impl IntoRawFd for DefaultStream {
248 fn into_raw_fd(self) -> RawFd {
249 self.inner.into_raw_fd()
250 }
251}
252
253#[cfg(unix)]
254impl From<DefaultStream> for OwnedFd {
255 fn from(stream: DefaultStream) -> Self {
256 stream.inner
257 }
258}
259
/// Expose the raw socket handle of the wrapped `TcpStream`.
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}

/// Borrow the wrapped `TcpStream` as an I/O-safe socket.
#[cfg(windows)]
impl AsSocket for DefaultStream {
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}

/// Give up ownership of the socket; the caller becomes responsible for closing it.
#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    fn into_raw_socket(self) -> RawSocket {
        IntoRawSocket::into_raw_socket(self.inner)
    }
}

/// Unwrap the stream into an owned socket.
#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    fn from(stream: DefaultStream) -> Self {
        OwnedSocket::from(stream.inner)
    }
}
287
288#[cfg(unix)]
289fn do_write(
290 stream: &DefaultStream,
291 bufs: &[IoSlice<'_>],
292 fds: &mut Vec<RawFdContainer>,
293) -> Result<usize> {
294 use rustix::io::Errno;
295 use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
296 use std::mem::MaybeUninit;
297
298 fn sendmsg_wrapper(
299 fd: BorrowedFd<'_>,
300 iov: &[IoSlice<'_>],
301 cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
302 flags: SendFlags,
303 ) -> Result<usize> {
304 loop {
305 match sendmsg(fd, iov, cmsgs, flags) {
306 Ok(n) => return Ok(n),
307 Err(Errno::INTR) => {}
309 Err(e) => return Err(e.into()),
310 }
311 }
312 }
313
314 let fd = stream.as_fd();
315
316 let res = if !fds.is_empty() {
317 let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
318 let rights = SendAncillaryMessage::ScmRights(&fds);
319
320 let mut cmsg_space = vec![MaybeUninit::uninit(); rights.size()];
321 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
322 assert!(cmsg_buffer.push(rights));
323
324 sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())?
325 } else {
326 sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())?
327 };
328
329 fds.clear();
331
332 Ok(res)
333}
334
335impl Stream for DefaultStream {
336 fn poll(&self, mode: PollMode) -> Result<()> {
337 use rustix::event::{poll, PollFd, PollFlags};
338 use rustix::io::Errno;
339
340 let mut poll_flags = PollFlags::empty();
341 if mode.readable() {
342 poll_flags |= PollFlags::IN;
343 }
344 if mode.writable() {
345 poll_flags |= PollFlags::OUT;
346 }
347 let fd = self.as_fd();
348 let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
349 loop {
350 match poll(&mut poll_fds, None) {
351 Ok(_) => break,
352 Err(Errno::INTR) => {}
353 Err(e) => return Err(e.into()),
354 }
355 }
356 Ok(())
358 }
359
360 fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
361 #[cfg(unix)]
362 {
363 use rustix::io::Errno;
364 use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
365 use std::io::IoSliceMut;
366 use std::mem::MaybeUninit;
367
368 let mut cmsg = [MaybeUninit::uninit(); 1024];
372 let mut iov = [IoSliceMut::new(buf)];
373 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);
374
375 let fd = self.as_fd();
376 let msg = loop {
377 match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
378 Ok(msg) => break msg,
379 Err(Errno::INTR) => {}
381 Err(e) => return Err(e.into()),
382 }
383 };
384
385 let fds_received = cmsg_buffer
386 .drain()
387 .filter_map(|cmsg| match cmsg {
388 RecvAncillaryMessage::ScmRights(r) => Some(r),
389 _ => None,
390 })
391 .flatten();
392
393 let mut cloexec_error = Ok(());
394 fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
395 cloexec_error?;
396
397 Ok(msg.bytes)
398 }
399 #[cfg(not(unix))]
400 {
401 use std::io::Read;
402 let _ = fd_storage;
404 loop {
405 match (&mut &self.inner).read(buf) {
407 Ok(n) => return Ok(n),
408 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
410 Err(e) => return Err(e),
411 }
412 }
413 }
414 }
415
416 fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
417 #[cfg(unix)]
418 {
419 do_write(self, &[IoSlice::new(buf)], fds)
420 }
421 #[cfg(not(unix))]
422 {
423 use std::io::{Error, ErrorKind, Write};
424 if !fds.is_empty() {
425 return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
426 }
427 loop {
428 match (&mut &self.inner).write(buf) {
430 Ok(n) => return Ok(n),
431 Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
433 Err(e) => return Err(e),
434 }
435 }
436 }
437 }
438
439 fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
440 #[cfg(unix)]
441 {
442 do_write(self, bufs, fds)
443 }
444 #[cfg(not(unix))]
445 {
446 use std::io::{Error, ErrorKind, Write};
447 if !fds.is_empty() {
448 return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
449 }
450 loop {
451 match (&mut &self.inner).write_vectored(bufs) {
453 Ok(n) => return Ok(n),
454 Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
456 Err(e) => return Err(e),
457 }
458 }
459 }
460 }
461}
462
463#[cfg(any(target_os = "linux", target_os = "android"))]
464fn connect_abstract_unix_stream(
465 path: &[u8],
466) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
467 use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
468 use rustix::net::{
469 connect, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
470 };
471
472 let socket = socket_with(
473 AddressFamily::UNIX,
474 SocketType::STREAM,
475 SocketFlags::CLOEXEC,
476 None,
477 )?;
478
479 connect(&socket, &SocketAddrUnix::new_abstract_name(path)?)?;
480
481 fcntl_setfl(&socket, fcntl_getfl(&socket)? | OFlags::NONBLOCK)?;
483
484 Ok(socket)
485}
486
487#[cfg(any(
489 target_os = "android",
490 target_os = "dragonfly",
491 target_os = "freebsd",
492 target_os = "linux",
493 target_os = "netbsd",
494 target_os = "openbsd"
495))]
496mod recvmsg {
497 use super::RawFdContainer;
498 use rustix::net::RecvFlags;
499
500 pub(crate) fn flags() -> RecvFlags {
501 RecvFlags::CMSG_CLOEXEC
502 }
503
504 pub(crate) fn after_recvmsg<'a>(
505 fds: impl Iterator<Item = RawFdContainer> + 'a,
506 _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
507 ) -> impl Iterator<Item = RawFdContainer> + 'a {
508 fds
509 }
510}
511
/// `recvmsg` helpers for the remaining Unix targets: `recvmsg` is called without
/// flags, and close-on-exec is set on each received descriptor afterwards with
/// `fcntl`. This is not atomic with the receive itself.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// Flags for every `recvmsg` call: none on these targets.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Lazily set `FD_CLOEXEC` on every descriptor as it is pulled from the iterator.
    ///
    /// Every descriptor is yielded even if updating its flags failed. The error of the
    /// last failing descriptor is stored in `cloexec_error`, which callers check after
    /// the iterator is exhausted.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let outcome = fcntl_getfd(fd)
                .and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = outcome {
                *cloexec_error = Err(errno);
            }
        })
    }
}
547
548mod peer_addr {
549 use super::{Family, PeerAddr};
550 use std::net::{Ipv4Addr, SocketAddr};
551
552 pub(super) fn local() -> PeerAddr {
554 let hostname = gethostname::gethostname()
555 .to_str()
556 .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
557 (Family::LOCAL, hostname)
558 }
559
560 pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
562 let ip = match addr {
563 SocketAddr::V4(addr) => *addr.ip(),
564 SocketAddr::V6(addr) => {
565 let ip = addr.ip();
566 if ip.is_loopback() {
567 Ipv4Addr::LOCALHOST
570 } else if let Some(ip) = ip.to_ipv4() {
571 ip
573 } else {
574 return (Family::INTERNET6, ip.octets().to_vec());
576 }
577 }
578 };
579
580 if ip.is_loopback() {
582 local()
583 } else {
584 (Family::INTERNET, ip.octets().to_vec())
585 }
586 }
587}