zbus/connection/socket/
mod.rs1#[cfg(feature = "p2p")]
2pub mod channel;
3#[cfg(feature = "p2p")]
4pub use channel::Channel;
5
6mod split;
7pub use split::{BoxedSplit, Split};
8
9#[cfg(unix)]
10pub(crate) mod command;
11#[cfg(unix)]
12pub(crate) use command::Command;
13mod tcp;
14mod unix;
15mod vsock;
16
17#[cfg(not(feature = "tokio"))]
18use async_io::Async;
19#[cfg(not(feature = "tokio"))]
20use std::sync::Arc;
21use std::{io, mem};
22use tracing::trace;
23
24use crate::{
25 conn::AuthMechanism,
26 fdo::ConnectionCredentials,
27 message::{
28 header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
29 PrimaryHeader,
30 },
31 padding_for_8_bytes, Message,
32};
33#[cfg(unix)]
34use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
35use zvariant::{
36 serialized::{self, Context},
37 Endian,
38};
39
/// Outcome of a single `recvmsg` call on Unix targets: the number of bytes read, paired with
/// any file descriptors that were received along with them.
#[cfg(unix)]
type RecvmsgResult = Result<(usize, Vec<OwnedFd>), io::Error>;

/// Outcome of a single `recvmsg` call on non-Unix targets: the number of bytes read.
#[cfg(not(unix))]
type RecvmsgResult = Result<usize, io::Error>;
45
46pub trait Socket {
60 type ReadHalf: ReadHalf;
61 type WriteHalf: WriteHalf;
62
63 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf>
65 where
66 Self: Sized;
67}
68
69#[async_trait::async_trait]
73pub trait ReadHalf: std::fmt::Debug + Send + Sync + 'static {
74 async fn receive_message(
93 &mut self,
94 seq: u64,
95 already_received_bytes: &mut Vec<u8>,
96 #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
97 ) -> crate::Result<Message> {
98 #[cfg(unix)]
99 let mut fds = vec![];
100 let mut bytes = if already_received_bytes.len() < MIN_MESSAGE_SIZE {
101 let mut bytes = vec![];
102 if !already_received_bytes.is_empty() {
103 mem::swap(already_received_bytes, &mut bytes);
104 }
105 let mut pos = bytes.len();
106 bytes.resize(MIN_MESSAGE_SIZE, 0);
107 while pos < MIN_MESSAGE_SIZE {
114 let res = self.recvmsg(&mut bytes[pos..]).await?;
115 let len = {
116 #[cfg(unix)]
117 {
118 fds.extend(res.1);
119 res.0
120 }
121 #[cfg(not(unix))]
122 {
123 res
124 }
125 };
126 pos += len;
127 if len == 0 {
128 return Err(std::io::Error::new(
129 std::io::ErrorKind::UnexpectedEof,
130 "failed to receive message",
131 )
132 .into());
133 }
134 }
135
136 bytes
137 } else {
138 already_received_bytes.drain(..MIN_MESSAGE_SIZE).collect()
139 };
140
141 let (primary_header, fields_len) = PrimaryHeader::read(&bytes)?;
142 let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
143 let body_padding = padding_for_8_bytes(header_len);
144 let body_len = primary_header.body_len() as usize;
145 let total_len = header_len + body_padding + body_len;
146 if total_len > MAX_MESSAGE_SIZE {
147 return Err(crate::Error::ExcessData);
148 }
149
150 if !already_received_bytes.is_empty() {
153 let pending = total_len - bytes.len();
155 let to_take = std::cmp::min(pending, already_received_bytes.len());
156 bytes.extend(already_received_bytes.drain(..to_take));
157 }
158 let mut pos = bytes.len();
159 bytes.resize(total_len, 0);
160
161 while pos < total_len {
163 let res = self.recvmsg(&mut bytes[pos..]).await?;
164 let read = {
165 #[cfg(unix)]
166 {
167 fds.extend(res.1);
168 res.0
169 }
170 #[cfg(not(unix))]
171 {
172 res
173 }
174 };
175 pos += read;
176 if read == 0 {
177 return Err(crate::Error::InputOutput(
178 std::io::Error::new(
179 std::io::ErrorKind::UnexpectedEof,
180 "failed to receive message",
181 )
182 .into(),
183 ));
184 }
185 }
186
187 let endian = Endian::from(primary_header.endian_sig());
189
190 #[cfg(unix)]
191 if !already_received_fds.is_empty() {
192 use crate::message::header::PRIMARY_HEADER_SIZE;
193
194 let ctxt = Context::new_dbus(endian, PRIMARY_HEADER_SIZE);
195 let encoded_fields =
196 serialized::Data::new(&bytes[PRIMARY_HEADER_SIZE..header_len], ctxt);
197 let fields: crate::message::Fields<'_> = encoded_fields.deserialize()?.0;
198 let num_required_fds = match fields.unix_fds {
199 Some(num_fds) => num_fds as usize,
200 _ => 0,
201 };
202 let num_pending = num_required_fds
203 .checked_sub(fds.len())
204 .ok_or_else(|| crate::Error::ExcessData)?;
205 if num_pending == 0 {
207 return Err(crate::Error::MissingParameter("Missing file descriptors"));
208 }
209 let mut already_received: Vec<_> = already_received_fds.drain(..num_pending).collect();
211 mem::swap(&mut already_received, &mut fds);
212 fds.extend(already_received);
213 }
214
215 let ctxt = Context::new_dbus(endian, 0);
216 #[cfg(unix)]
217 let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
218 #[cfg(not(unix))]
219 let bytes = serialized::Data::new(bytes, ctxt);
220 Message::from_raw_parts(bytes, seq)
221 }
222
223 async fn recvmsg(&mut self, _buf: &mut [u8]) -> RecvmsgResult {
231 unimplemented!("`ReadHalf` implementers must either override `read_message` or `recvmsg`");
232 }
233
234 fn can_pass_unix_fd(&self) -> bool {
238 false
239 }
240
241 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
243 Ok(ConnectionCredentials::default())
244 }
245
246 fn auth_mechanism(&self) -> AuthMechanism {
250 AuthMechanism::External
251 }
252}
253
254#[async_trait::async_trait]
258pub trait WriteHalf: std::fmt::Debug + Send + Sync + 'static {
259 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
266 let data = msg.data();
267 let serial = msg.primary_header().serial_num();
268
269 trace!("Sending message: {:?}", msg);
270 let mut pos = 0;
271 while pos < data.len() {
272 #[cfg(unix)]
273 let fds = if pos == 0 {
274 data.fds().iter().map(|f| f.as_fd()).collect()
275 } else {
276 vec![]
277 };
278 pos += self
279 .sendmsg(
280 &data[pos..],
281 #[cfg(unix)]
282 &fds,
283 )
284 .await?;
285 }
286 trace!("Sent message with serial: {}", serial);
287
288 Ok(())
289 }
290
291 async fn sendmsg(
306 &mut self,
307 _buffer: &[u8],
308 #[cfg(unix)] _fds: &[BorrowedFd<'_>],
309 ) -> io::Result<usize> {
310 unimplemented!("`WriteHalf` implementers must either override `send_message` or `sendmsg`");
311 }
312
313 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
318 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
319 Ok(None)
320 }
321
322 async fn close(&mut self) -> io::Result<()>;
326
327 fn can_pass_unix_fd(&self) -> bool {
331 false
332 }
333
334 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
336 Ok(ConnectionCredentials::default())
337 }
338}
339
340#[async_trait::async_trait]
341impl ReadHalf for Box<dyn ReadHalf> {
342 fn can_pass_unix_fd(&self) -> bool {
343 (**self).can_pass_unix_fd()
344 }
345
346 async fn receive_message(
347 &mut self,
348 seq: u64,
349 already_received_bytes: &mut Vec<u8>,
350 #[cfg(unix)] already_received_fds: &mut Vec<OwnedFd>,
351 ) -> crate::Result<Message> {
352 (**self)
353 .receive_message(
354 seq,
355 already_received_bytes,
356 #[cfg(unix)]
357 already_received_fds,
358 )
359 .await
360 }
361
362 async fn recvmsg(&mut self, buf: &mut [u8]) -> RecvmsgResult {
363 (**self).recvmsg(buf).await
364 }
365
366 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
367 (**self).peer_credentials().await
368 }
369
370 fn auth_mechanism(&self) -> AuthMechanism {
371 (**self).auth_mechanism()
372 }
373}
374
375#[async_trait::async_trait]
376impl WriteHalf for Box<dyn WriteHalf> {
377 async fn send_message(&mut self, msg: &Message) -> crate::Result<()> {
378 (**self).send_message(msg).await
379 }
380
381 async fn sendmsg(
382 &mut self,
383 buffer: &[u8],
384 #[cfg(unix)] fds: &[BorrowedFd<'_>],
385 ) -> io::Result<usize> {
386 (**self)
387 .sendmsg(
388 buffer,
389 #[cfg(unix)]
390 fds,
391 )
392 .await
393 }
394
395 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
396 async fn send_zero_byte(&mut self) -> io::Result<Option<usize>> {
397 (**self).send_zero_byte().await
398 }
399
400 async fn close(&mut self) -> io::Result<()> {
401 (**self).close().await
402 }
403
404 fn can_pass_unix_fd(&self) -> bool {
405 (**self).can_pass_unix_fd()
406 }
407
408 async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
409 (**self).peer_credentials().await
410 }
411}
412
413#[cfg(not(feature = "tokio"))]
414impl<T> Socket for Async<T>
415where
416 T: std::fmt::Debug + Send + Sync,
417 Arc<Async<T>>: ReadHalf + WriteHalf,
418{
419 type ReadHalf = Arc<Async<T>>;
420 type WriteHalf = Arc<Async<T>>;
421
422 fn split(self) -> Split<Self::ReadHalf, Self::WriteHalf> {
423 let arc = Arc::new(self);
424
425 Split {
426 read: arc.clone(),
427 write: arc,
428 }
429 }
430}