zbus/address/transport/
tcp.rs1use super::encode_percents;
2use crate::{Error, Result};
3#[cfg(not(feature = "tokio"))]
4use async_io::Async;
5#[cfg(not(feature = "tokio"))]
6use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
7use std::{
8 collections::HashMap,
9 fmt::{Display, Formatter},
10 str::FromStr,
11};
12#[cfg(feature = "tokio")]
13use tokio::net::TcpStream;
14
15#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct Tcp {
18 pub(super) host: String,
19 pub(super) bind: Option<String>,
20 pub(super) port: u16,
21 pub(super) family: Option<TcpTransportFamily>,
22 pub(super) nonce_file: Option<Vec<u8>>,
23}
24
25impl Tcp {
26 pub fn new(host: &str, port: u16) -> Self {
28 Self {
29 host: host.to_owned(),
30 port,
31 bind: None,
32 family: None,
33 nonce_file: None,
34 }
35 }
36
37 pub fn set_bind(mut self, bind: Option<String>) -> Self {
39 self.bind = bind;
40
41 self
42 }
43
44 pub fn set_family(mut self, family: Option<TcpTransportFamily>) -> Self {
46 self.family = family;
47
48 self
49 }
50
51 pub fn set_nonce_file(mut self, nonce_file: Option<Vec<u8>>) -> Self {
53 self.nonce_file = nonce_file;
54
55 self
56 }
57
58 pub fn host(&self) -> &str {
60 &self.host
61 }
62
63 pub fn bind(&self) -> Option<&str> {
65 self.bind.as_deref()
66 }
67
68 pub fn port(&self) -> u16 {
70 self.port
71 }
72
73 pub fn family(&self) -> Option<TcpTransportFamily> {
75 self.family
76 }
77
78 pub fn nonce_file(&self) -> Option<&[u8]> {
80 self.nonce_file.as_deref()
81 }
82
83 pub fn take_nonce_file(&mut self) -> Option<Vec<u8>> {
85 self.nonce_file.take()
86 }
87
88 pub(super) fn from_options(
89 opts: HashMap<&str, &str>,
90 nonce_tcp_required: bool,
91 ) -> Result<Self> {
92 let bind = None;
93 if opts.contains_key("bind") {
94 return Err(Error::Address("`bind` isn't yet supported".into()));
95 }
96
97 let host = opts
98 .get("host")
99 .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))?
100 .to_string();
101 let port = opts
102 .get("port")
103 .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?;
104 let port = port
105 .parse::<u16>()
106 .map_err(|_| Error::Address("invalid tcp `port`".into()))?;
107 let family = opts
108 .get("family")
109 .map(|f| TcpTransportFamily::from_str(f))
110 .transpose()?;
111 let nonce_file = opts
112 .get("noncefile")
113 .map(|f| super::decode_percents(f))
114 .transpose()?;
115 if nonce_tcp_required && nonce_file.is_none() {
116 return Err(Error::Address(
117 "nonce-tcp address is missing `noncefile`".into(),
118 ));
119 }
120
121 Ok(Self {
122 host,
123 bind,
124 port,
125 family,
126 nonce_file,
127 })
128 }
129
130 #[cfg(not(feature = "tokio"))]
131 pub(super) async fn connect(self) -> Result<Async<TcpStream>> {
132 let addrs = crate::Task::spawn_blocking(
133 move || -> Result<Vec<SocketAddr>> {
134 let addrs = (self.host(), self.port()).to_socket_addrs()?.filter(|a| {
135 if let Some(family) = self.family() {
136 if family == TcpTransportFamily::Ipv4 {
137 a.is_ipv4()
138 } else {
139 a.is_ipv6()
140 }
141 } else {
142 true
143 }
144 });
145 Ok(addrs.collect())
146 },
147 "connect tcp",
148 )
149 .await
150 .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?;
151
152 let mut last_err = Error::Address("Failed to connect".into());
154 for addr in addrs {
155 match Async::<TcpStream>::connect(addr).await {
156 Ok(stream) => return Ok(stream),
157 Err(e) => last_err = e.into(),
158 }
159 }
160
161 Err(last_err)
162 }
163
164 #[cfg(feature = "tokio")]
165 pub(super) async fn connect(self) -> Result<TcpStream> {
166 TcpStream::connect((self.host(), self.port()))
167 .await
168 .map_err(|e| Error::InputOutput(e.into()))
169 }
170}
171
172impl Display for Tcp {
173 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
174 match self.nonce_file() {
175 Some(nonce_file) => {
176 f.write_str("nonce-tcp:noncefile=")?;
177 encode_percents(f, nonce_file)?;
178 f.write_str(",")?;
179 }
180 None => f.write_str("tcp:")?,
181 }
182 f.write_str("host=")?;
183
184 encode_percents(f, self.host().as_bytes())?;
185
186 write!(f, ",port={}", self.port())?;
187
188 if let Some(bind) = self.bind() {
189 f.write_str(",bind=")?;
190 encode_percents(f, bind.as_bytes())?;
191 }
192
193 if let Some(family) = self.family() {
194 write!(f, ",family={family}")?;
195 }
196
197 Ok(())
198 }
199}
200
/// Address family a TCP transport may be restricted to.
///
/// Parsed from and displayed as the `family` key of a TCP address
/// (`ipv4` / `ipv6`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpTransportFamily {
    /// Only IPv4 addresses.
    Ipv4,
    /// Only IPv6 addresses.
    Ipv6,
}
207
208impl FromStr for TcpTransportFamily {
209 type Err = Error;
210
211 fn from_str(family: &str) -> Result<Self> {
212 match family {
213 "ipv4" => Ok(Self::Ipv4),
214 "ipv6" => Ok(Self::Ipv6),
215 _ => Err(Error::Address(format!(
216 "invalid tcp address `family`: {family}"
217 ))),
218 }
219 }
220}
221
222impl Display for TcpTransportFamily {
223 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Self::Ipv4 => write!(f, "ipv4"),
226 Self::Ipv6 => write!(f, "ipv6"),
227 }
228 }
229}