zbus/address/transport/tcp.rs

1use super::encode_percents;
2use crate::{Error, Result};
3#[cfg(not(feature = "tokio"))]
4use async_io::Async;
5#[cfg(not(feature = "tokio"))]
6use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
7use std::{
8    collections::HashMap,
9    fmt::{Display, Formatter},
10    str::FromStr,
11};
12#[cfg(feature = "tokio")]
13use tokio::net::TcpStream;
14
15/// A TCP transport in a D-Bus address.
16#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct Tcp {
18    pub(super) host: String,
19    pub(super) bind: Option<String>,
20    pub(super) port: u16,
21    pub(super) family: Option<TcpTransportFamily>,
22    pub(super) nonce_file: Option<Vec<u8>>,
23}
24
25impl Tcp {
26    /// Create a new TCP transport with the given host and port.
27    pub fn new(host: &str, port: u16) -> Self {
28        Self {
29            host: host.to_owned(),
30            port,
31            bind: None,
32            family: None,
33            nonce_file: None,
34        }
35    }
36
37    /// Set the `tcp:` address `bind` value.
38    pub fn set_bind(mut self, bind: Option<String>) -> Self {
39        self.bind = bind;
40
41        self
42    }
43
44    /// Set the `tcp:` address `family` value.
45    pub fn set_family(mut self, family: Option<TcpTransportFamily>) -> Self {
46        self.family = family;
47
48        self
49    }
50
51    /// Set the `tcp:` address `noncefile` value.
52    pub fn set_nonce_file(mut self, nonce_file: Option<Vec<u8>>) -> Self {
53        self.nonce_file = nonce_file;
54
55        self
56    }
57
58    /// The `tcp:` address `host` value.
59    pub fn host(&self) -> &str {
60        &self.host
61    }
62
63    /// The `tcp:` address `bind` value.
64    pub fn bind(&self) -> Option<&str> {
65        self.bind.as_deref()
66    }
67
68    /// The `tcp:` address `port` value.
69    pub fn port(&self) -> u16 {
70        self.port
71    }
72
73    /// The `tcp:` address `family` value.
74    pub fn family(&self) -> Option<TcpTransportFamily> {
75        self.family
76    }
77
78    /// The nonce file path, if any.
79    pub fn nonce_file(&self) -> Option<&[u8]> {
80        self.nonce_file.as_deref()
81    }
82
83    /// Take ownership of the nonce file path, if any.
84    pub fn take_nonce_file(&mut self) -> Option<Vec<u8>> {
85        self.nonce_file.take()
86    }
87
88    pub(super) fn from_options(
89        opts: HashMap<&str, &str>,
90        nonce_tcp_required: bool,
91    ) -> Result<Self> {
92        let bind = None;
93        if opts.contains_key("bind") {
94            return Err(Error::Address("`bind` isn't yet supported".into()));
95        }
96
97        let host = opts
98            .get("host")
99            .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))?
100            .to_string();
101        let port = opts
102            .get("port")
103            .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?;
104        let port = port
105            .parse::<u16>()
106            .map_err(|_| Error::Address("invalid tcp `port`".into()))?;
107        let family = opts
108            .get("family")
109            .map(|f| TcpTransportFamily::from_str(f))
110            .transpose()?;
111        let nonce_file = opts
112            .get("noncefile")
113            .map(|f| super::decode_percents(f))
114            .transpose()?;
115        if nonce_tcp_required && nonce_file.is_none() {
116            return Err(Error::Address(
117                "nonce-tcp address is missing `noncefile`".into(),
118            ));
119        }
120
121        Ok(Self {
122            host,
123            bind,
124            port,
125            family,
126            nonce_file,
127        })
128    }
129
130    #[cfg(not(feature = "tokio"))]
131    pub(super) async fn connect(self) -> Result<Async<TcpStream>> {
132        let addrs = crate::Task::spawn_blocking(
133            move || -> Result<Vec<SocketAddr>> {
134                let addrs = (self.host(), self.port()).to_socket_addrs()?.filter(|a| {
135                    if let Some(family) = self.family() {
136                        if family == TcpTransportFamily::Ipv4 {
137                            a.is_ipv4()
138                        } else {
139                            a.is_ipv6()
140                        }
141                    } else {
142                        true
143                    }
144                });
145                Ok(addrs.collect())
146            },
147            "connect tcp",
148        )
149        .await
150        .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?;
151
152        // we could attempt connections in parallel?
153        let mut last_err = Error::Address("Failed to connect".into());
154        for addr in addrs {
155            match Async::<TcpStream>::connect(addr).await {
156                Ok(stream) => return Ok(stream),
157                Err(e) => last_err = e.into(),
158            }
159        }
160
161        Err(last_err)
162    }
163
164    #[cfg(feature = "tokio")]
165    pub(super) async fn connect(self) -> Result<TcpStream> {
166        TcpStream::connect((self.host(), self.port()))
167            .await
168            .map_err(|e| Error::InputOutput(e.into()))
169    }
170}
171
172impl Display for Tcp {
173    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174        match self.nonce_file() {
175            Some(nonce_file) => {
176                f.write_str("nonce-tcp:noncefile=")?;
177                encode_percents(f, nonce_file)?;
178                f.write_str(",")?;
179            }
180            None => f.write_str("tcp:")?,
181        }
182        f.write_str("host=")?;
183
184        encode_percents(f, self.host().as_bytes())?;
185
186        write!(f, ",port={}", self.port())?;
187
188        if let Some(bind) = self.bind() {
189            f.write_str(",bind=")?;
190            encode_percents(f, bind.as_bytes())?;
191        }
192
193        if let Some(family) = self.family() {
194            write!(f, ",family={family}")?;
195        }
196
197        Ok(())
198    }
199}
200
/// A `tcp:` address family.
///
/// Corresponds to the `family` key of a TCP D-Bus address.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TcpTransportFamily {
    /// IPv4 (`family=ipv4`).
    Ipv4,
    /// IPv6 (`family=ipv6`).
    Ipv6,
}
207
208impl FromStr for TcpTransportFamily {
209    type Err = Error;
210
211    fn from_str(family: &str) -> Result<Self> {
212        match family {
213            "ipv4" => Ok(Self::Ipv4),
214            "ipv6" => Ok(Self::Ipv6),
215            _ => Err(Error::Address(format!(
216                "invalid tcp address `family`: {family}"
217            ))),
218        }
219    }
220}
221
222impl Display for TcpTransportFamily {
223    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
224        match self {
225            Self::Ipv4 => write!(f, "ipv4"),
226            Self::Ipv6 => write!(f, "ipv6"),
227        }
228    }
229}