1use async_broadcast::{broadcast, InactiveReceiver, Receiver, Sender as Broadcaster};
3use enumflags2::BitFlags;
4use event_listener::{Event, EventListener};
5use ordered_stream::{OrderedFuture, OrderedStream, PollResult};
6use std::{
7 collections::HashMap,
8 io::{self, ErrorKind},
9 num::NonZeroU32,
10 pin::Pin,
11 sync::{Arc, OnceLock, Weak},
12 task::{Context, Poll},
13 time::Duration,
14};
15use tracing::{debug, info_span, instrument, trace, trace_span, warn, Instrument};
16use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
17use zvariant::ObjectPath;
18
19use futures_core::Future;
20use futures_lite::StreamExt;
21
22use crate::{
23 async_lock::{Mutex, Semaphore, SemaphorePermit},
24 fdo::{ConnectionCredentials, ReleaseNameReply, RequestNameFlags, RequestNameReply},
25 is_flatpak,
26 message::{Flags, Message, Type},
27 timeout::timeout,
28 DBusError, Error, Executor, MatchRule, MessageStream, ObjectServer, OwnedGuid, OwnedMatchRule,
29 Result, Task,
30};
31
32mod builder;
33pub use builder::Builder;
34
35pub mod socket;
36pub use socket::Socket;
37
38mod socket_reader;
39use socket_reader::SocketReader;
40
41pub(crate) mod handshake;
42pub use handshake::AuthMechanism;
43use handshake::Authenticated;
44
/// Default capacity of message broadcast channels: the catch-all channel created in
/// `Connection::new` and each match-rule subscription created without an explicit
/// `max_queued`.
const DEFAULT_MAX_QUEUED: usize = 64;
/// Default capacity of the single channel that carries method returns and errors.
const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8;
47
/// State shared by every clone of a [`Connection`].
///
/// NOTE: fields are dropped in declaration order, so keep the order stable.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// GUID of the server, taken from the authentication handshake.
    server_guid: OwnedGuid,
    /// Whether Unix FD passing was negotiated; `Connection::send` rejects messages that carry
    /// file descriptors when this is `false`.
    #[cfg(unix)]
    cap_unix_fd: bool,
    /// `true` for a connection to a message bus, `false` for a peer-to-peer connection.
    #[cfg(feature = "p2p")]
    bus_conn: bool,
    /// Unique name of this connection; set at most once.
    unique_name: OnceLock<OwnedUniqueName>,
    /// Well-known names requested through this connection, with their local status.
    registered_names: Mutex<HashMap<WellKnownName<'static>, NameStatus>>,

    /// Notified on send and close; a clone is also handed to the socket reader.
    activity_event: Arc<Event>,
    /// Write half of the socket; the lock serializes concurrent writers.
    socket_write: Mutex<Box<dyn socket::WriteHalf>>,

    /// Executor that runs this connection's background tasks.
    executor: Executor<'static>,

    // Stored only so the connection owns the socket reader task (presumably dropping the
    // `Task` would cancel it); never read, hence `allow(unused)`.
    #[allow(unused)]
    socket_reader_task: OnceLock<Task<()>>,

    /// Inactive receiver of the catch-all channel (registered under the `None` rule).
    pub(crate) msg_receiver: InactiveReceiver<Result<Message>>,
    /// Inactive receiver of the channel carrying method returns and errors.
    pub(crate) method_return_receiver: InactiveReceiver<Result<Message>>,
    /// Broadcaster per match rule (`None` = every message); shared with the socket reader.
    msg_senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,

    /// Reference-counted match-rule subscriptions (see `Connection::add_match`).
    subscriptions: Mutex<Subscriptions>,

    /// Lazily created object server.
    object_server: OnceLock<ObjectServer>,
    /// Object server dispatch task, spawned at most once.
    object_server_dispatch_task: OnceLock<Task<()>>,

    /// Notified when this value is dropped; awaited by `Connection::graceful_shutdown`.
    drop_event: Event,

    /// Optional upper bound on how long `Connection::call_method` waits for a reply.
    method_timeout: Option<Duration>,
}
82
83impl Drop for ConnectionInner {
84 fn drop(&mut self) {
85 self.drop_event.notify(usize::MAX);
89 }
90}
91
/// Match-rule subscriptions: rule -> (number of subscribers, inactive receiver of the rule's
/// broadcast channel).
type Subscriptions = HashMap<OwnedMatchRule, (u64, InactiveReceiver<Result<Message>>)>;

/// Sending half of a message broadcast channel.
pub(crate) type MsgBroadcaster = Broadcaster<Result<Message>>;
95
/// A D-Bus connection.
///
/// Cloning is cheap: all clones share one `ConnectionInner` through an [`Arc`], which is
/// dropped once the last strong reference goes away.
#[derive(Clone, Debug)]
#[must_use = "Dropping a `Connection` will close the underlying socket."]
pub struct Connection {
    pub(crate) inner: Arc<ConnectionInner>,
}
215
/// A sent method call whose reply is being awaited.
///
/// Resolves (as a [`Future`] or an [`OrderedFuture`]) to the first method-return or error
/// message whose reply serial equals `serial`.
#[derive(Debug)]
pub(crate) struct PendingMethodCall {
    /// Subscription to method returns and errors; `None` once the reply has been delivered.
    stream: Option<MessageStream>,
    /// Serial number of the call, compared with each incoming message's reply serial.
    serial: NonZeroU32,
}
226
227impl Future for PendingMethodCall {
228 type Output = Result<Message>;
229
230 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
231 self.poll_before(cx, None).map(|ret| {
232 ret.map(|(_, r)| r).unwrap_or_else(|| {
233 Err(crate::Error::InputOutput(
234 io::Error::new(ErrorKind::BrokenPipe, "socket closed").into(),
235 ))
236 })
237 })
238 }
239}
240
241impl OrderedFuture for PendingMethodCall {
242 type Output = Result<Message>;
243 type Ordering = zbus::message::Sequence;
244
245 fn poll_before(
246 self: Pin<&mut Self>,
247 cx: &mut Context<'_>,
248 before: Option<&Self::Ordering>,
249 ) -> Poll<Option<(Self::Ordering, Self::Output)>> {
250 let this = self.get_mut();
251 if let Some(stream) = &mut this.stream {
252 loop {
253 match Pin::new(&mut *stream).poll_next_before(cx, before) {
254 Poll::Ready(PollResult::Item {
255 data: Ok(msg),
256 ordering,
257 }) => {
258 if msg.header().reply_serial() != Some(this.serial) {
259 continue;
260 }
261 let res = match msg.message_type() {
262 Type::Error => Err(msg.into()),
263 Type::MethodReturn => Ok(msg),
264 _ => continue,
265 };
266 this.stream = None;
267 return Poll::Ready(Some((ordering, res)));
268 }
269 Poll::Ready(PollResult::Item {
270 data: Err(e),
271 ordering,
272 }) => {
273 return Poll::Ready(Some((ordering, Err(e))));
274 }
275
276 Poll::Ready(PollResult::NoneBefore) => {
277 return Poll::Ready(None);
278 }
279 Poll::Ready(PollResult::Terminated) => {
280 return Poll::Ready(None);
281 }
282 Poll::Pending => return Poll::Pending,
283 }
284 }
285 }
286 Poll::Ready(None)
287 }
288}
289
290impl Connection {
291 pub async fn send(&self, msg: &Message) -> Result<()> {
293 #[cfg(unix)]
294 if !msg.data().fds().is_empty() && !self.inner.cap_unix_fd {
295 return Err(Error::Unsupported);
296 }
297
298 self.inner.activity_event.notify(usize::MAX);
299 let mut write = self.inner.socket_write.lock().await;
300
301 write.send_message(msg).await
302 }
303
304 pub async fn call_method<'d, 'p, 'i, 'm, D, P, I, M, B>(
311 &self,
312 destination: Option<D>,
313 path: P,
314 interface: Option<I>,
315 method_name: M,
316 body: &B,
317 ) -> Result<Message>
318 where
319 D: TryInto<BusName<'d>>,
320 P: TryInto<ObjectPath<'p>>,
321 I: TryInto<InterfaceName<'i>>,
322 M: TryInto<MemberName<'m>>,
323 D::Error: Into<Error>,
324 P::Error: Into<Error>,
325 I::Error: Into<Error>,
326 M::Error: Into<Error>,
327 B: serde::ser::Serialize + zvariant::DynamicType,
328 {
329 let method = self
330 .call_method_raw(
331 destination,
332 path,
333 interface,
334 method_name,
335 BitFlags::empty(),
336 body,
337 )
338 .await?
339 .expect("no reply");
340
341 if let Some(tout) = self.method_timeout() {
342 timeout(method, tout).await
343 } else {
344 method.await
345 }
346 }
347
348 pub(crate) async fn call_method_raw<'d, 'p, 'i, 'm, D, P, I, M, B>(
359 &self,
360 destination: Option<D>,
361 path: P,
362 interface: Option<I>,
363 method_name: M,
364 flags: BitFlags<Flags>,
365 body: &B,
366 ) -> Result<Option<PendingMethodCall>>
367 where
368 D: TryInto<BusName<'d>>,
369 P: TryInto<ObjectPath<'p>>,
370 I: TryInto<InterfaceName<'i>>,
371 M: TryInto<MemberName<'m>>,
372 D::Error: Into<Error>,
373 P::Error: Into<Error>,
374 I::Error: Into<Error>,
375 M::Error: Into<Error>,
376 B: serde::ser::Serialize + zvariant::DynamicType,
377 {
378 let _permit = acquire_serial_num_semaphore().await;
379
380 let mut builder = Message::method_call(path, method_name)?;
381 if let Some(sender) = self.unique_name() {
382 builder = builder.sender(sender)?
383 }
384 if let Some(destination) = destination {
385 builder = builder.destination(destination)?
386 }
387 if let Some(interface) = interface {
388 builder = builder.interface(interface)?
389 }
390 for flag in flags {
391 builder = builder.with_flags(flag)?;
392 }
393 let msg = builder.build(body)?;
394
395 let msg_receiver = self.inner.method_return_receiver.activate_cloned();
396 let stream = Some(MessageStream::for_subscription_channel(
397 msg_receiver,
398 None,
400 self,
401 ));
402 let serial = msg.primary_header().serial_num();
403 self.send(&msg).await?;
404 if flags.contains(Flags::NoReplyExpected) {
405 Ok(None)
406 } else {
407 Ok(Some(PendingMethodCall { stream, serial }))
408 }
409 }
410
411 pub async fn emit_signal<'d, 'p, 'i, 'm, D, P, I, M, B>(
415 &self,
416 destination: Option<D>,
417 path: P,
418 interface: I,
419 signal_name: M,
420 body: &B,
421 ) -> Result<()>
422 where
423 D: TryInto<BusName<'d>>,
424 P: TryInto<ObjectPath<'p>>,
425 I: TryInto<InterfaceName<'i>>,
426 M: TryInto<MemberName<'m>>,
427 D::Error: Into<Error>,
428 P::Error: Into<Error>,
429 I::Error: Into<Error>,
430 M::Error: Into<Error>,
431 B: serde::ser::Serialize + zvariant::DynamicType,
432 {
433 let _permit = acquire_serial_num_semaphore().await;
434
435 let mut b = Message::signal(path, interface, signal_name)?;
436 if let Some(sender) = self.unique_name() {
437 b = b.sender(sender)?;
438 }
439 if let Some(destination) = destination {
440 b = b.destination(destination)?;
441 }
442 let m = b.build(body)?;
443
444 self.send(&m).await
445 }
446
447 pub async fn reply<B>(&self, call: &zbus::message::Header<'_>, body: &B) -> Result<()>
452 where
453 B: serde::ser::Serialize + zvariant::DynamicType,
454 {
455 let _permit = acquire_serial_num_semaphore().await;
456
457 let mut b = Message::method_return(call)?;
458 if let Some(sender) = self.unique_name() {
459 b = b.sender(sender)?;
460 }
461 let m = b.build(body)?;
462 self.send(&m).await
463 }
464
465 pub async fn reply_error<'e, E, B>(
470 &self,
471 call: &zbus::message::Header<'_>,
472 error_name: E,
473 body: &B,
474 ) -> Result<()>
475 where
476 B: serde::ser::Serialize + zvariant::DynamicType,
477 E: TryInto<ErrorName<'e>>,
478 E::Error: Into<Error>,
479 {
480 let _permit = acquire_serial_num_semaphore().await;
481
482 let mut b = Message::error(call, error_name)?;
483 if let Some(sender) = self.unique_name() {
484 b = b.sender(sender)?;
485 }
486 let m = b.build(body)?;
487 self.send(&m).await
488 }
489
490 pub async fn reply_dbus_error(
495 &self,
496 call: &zbus::message::Header<'_>,
497 err: impl DBusError,
498 ) -> Result<()> {
499 let _permit = acquire_serial_num_semaphore().await;
500
501 let m = err.create_reply(call)?;
502 self.send(&m).await
503 }
504
505 pub async fn request_name<'w, W>(&self, well_known_name: W) -> Result<()>
535 where
536 W: TryInto<WellKnownName<'w>>,
537 W::Error: Into<Error>,
538 {
539 self.request_name_with_flags(well_known_name, BitFlags::default())
540 .await
541 .map(|_| ())
542 }
543
/// Requests ownership of `well_known_name`, passing `flags` along with the request.
///
/// Names already recorded for this connection are answered locally, without a bus
/// round-trip: [`RequestNameReply::AlreadyOwner`] or [`RequestNameReply::InQueue`]. On a
/// peer-to-peer connection the name is simply recorded as owned and
/// [`RequestNameReply::PrimaryOwner`] is returned.
///
/// On a bus, `RequestName` is called and the outcome is tracked locally:
/// - if `flags` contains [`RequestNameFlags::AllowReplacement`], a task watches for
///   `NameLost` and forgets the name when it arrives;
/// - if the reply is `InQueue`, a task waits for `NameAcquired` and then marks the name as
///   owned, starting the `NameLost` watcher if there is one.
///
/// # Errors
///
/// Returns [`Error::NameTaken`] if the bus replies `Exists`. Propagates name-conversion,
/// match-rule and method-call errors.
pub async fn request_name_with_flags<'w, W>(
    &self,
    well_known_name: W,
    flags: BitFlags<RequestNameFlags>,
) -> Result<RequestNameReply>
where
    W: TryInto<WellKnownName<'w>>,
    W::Error: Into<Error>,
{
    let well_known_name = well_known_name.try_into().map_err(Into::into)?;
    // Held until the function returns, so concurrent requests on this connection are
    // serialized.
    let mut names = self.inner.registered_names.lock().await;

    match names.get(&well_known_name) {
        Some(NameStatus::Owner(_)) => return Ok(RequestNameReply::AlreadyOwner),
        Some(NameStatus::Queued(_)) => return Ok(RequestNameReply::InQueue),
        None => (),
    }

    if !self.is_bus() {
        // No bus to ask: the name is ours by definition.
        names.insert(well_known_name.to_owned(), NameStatus::Owner(None));

        return Ok(RequestNameReply::PrimaryOwner);
    }

    // Both signal subscriptions are set up before `RequestName` is called, presumably so that
    // a signal sent right after the reply cannot be missed.
    let acquired_match_rule = MatchRule::fdo_signal_builder("NameAcquired")
        .arg(0, well_known_name.as_ref())
        .unwrap()
        .build();
    let mut acquired_stream = self.add_match(acquired_match_rule.into(), None).await?;
    let lost_match_rule = MatchRule::fdo_signal_builder("NameLost")
        .arg(0, well_known_name.as_ref())
        .unwrap()
        .build();
    let mut lost_stream = self.add_match(lost_match_rule.into(), None).await?;
    let reply = self
        .call_method(
            Some("org.freedesktop.DBus"),
            "/org/freedesktop/DBus",
            Some("org.freedesktop.DBus"),
            "RequestName",
            &(well_known_name.clone(), flags),
        )
        .await?
        .body()
        .deserialize::<RequestNameReply>()?;
    let lost_task_name = format!("monitor_name_lost{{name={well_known_name}}}");
    let lost_task_name_span = info_span!("monitor_name_lost", name = %well_known_name);
    // Only built (not yet spawned) here. It is spawned once the name is actually owned:
    // immediately below, or by the `NameAcquired` task if the request was queued.
    let name_lost_fut = if flags.contains(RequestNameFlags::AllowReplacement) {
        let weak_conn = WeakConnection::from(self);
        let well_known_name = well_known_name.to_owned();
        Some(
            async move {
                loop {
                    let signal = lost_stream.next().await;
                    // Stop once the connection itself is gone.
                    let inner = match weak_conn.upgrade() {
                        Some(conn) => conn.inner.clone(),
                        None => break,
                    };

                    match signal {
                        Some(signal) => match signal {
                            Ok(_) => {
                                tracing::info!(
                                    "Connection `{}` lost name `{}`",
                                    inner.unique_name.get().unwrap(),
                                    well_known_name
                                );
                                inner.registered_names.lock().await.remove(&well_known_name);

                                break;
                            }
                            Err(e) => warn!("Failed to parse `NameLost` signal: {}", e),
                        },
                        None => {
                            trace!("`NameLost` signal stream closed");
                            break;
                        }
                    }
                }
            }
            .instrument(lost_task_name_span),
        )
    } else {
        None
    };
    let status = match reply {
        RequestNameReply::InQueue => {
            let weak_conn = WeakConnection::from(self);
            let well_known_name = well_known_name.to_owned();
            let task_name = format!("monitor_name_acquired{{name={well_known_name}}}");
            let task_name_span = info_span!("monitor_name_acquired", name = %well_known_name);
            let task = self.executor().spawn(
                async move {
                    loop {
                        let signal = acquired_stream.next().await;
                        let inner = match weak_conn.upgrade() {
                            Some(conn) => conn.inner.clone(),
                            None => break,
                        };
                        match signal {
                            Some(signal) => match signal {
                                Ok(_) => {
                                    let mut names = inner.registered_names.lock().await;
                                    if let Some(status) = names.get_mut(&well_known_name) {
                                        // Promote Queued -> Owner, and start the `NameLost`
                                        // watcher now that the name is actually owned.
                                        let task = name_lost_fut.map(|fut| {
                                            inner.executor.spawn(fut, &lost_task_name)
                                        });
                                        *status = NameStatus::Owner(task);

                                        break;
                                    }
                                    // Otherwise the name is no longer recorded locally; keep
                                    // looping on the signal stream.
                                }
                                Err(e) => warn!("Failed to parse `NameAcquired` signal: {}", e),
                            },
                            None => {
                                trace!("`NameAcquired` signal stream closed");
                                break;
                            }
                        }
                    }
                }
                .instrument(task_name_span),
                &task_name,
            );

            NameStatus::Queued(task)
        }
        RequestNameReply::PrimaryOwner | RequestNameReply::AlreadyOwner => {
            let task = name_lost_fut.map(|fut| self.executor().spawn(fut, &lost_task_name));

            NameStatus::Owner(task)
        }
        RequestNameReply::Exists => return Err(Error::NameTaken),
    };

    names.insert(well_known_name.to_owned(), status);

    Ok(reply)
}
771
772 pub async fn release_name<'w, W>(&self, well_known_name: W) -> Result<bool>
781 where
782 W: TryInto<WellKnownName<'w>>,
783 W::Error: Into<Error>,
784 {
785 let well_known_name: WellKnownName<'w> = well_known_name.try_into().map_err(Into::into)?;
786 let mut names = self.inner.registered_names.lock().await;
787 if names.remove(&well_known_name.to_owned()).is_none() {
789 return Ok(false);
790 };
791
792 if !self.is_bus() {
793 return Ok(true);
794 }
795
796 self.call_method(
797 Some("org.freedesktop.DBus"),
798 "/org/freedesktop/DBus",
799 Some("org.freedesktop.DBus"),
800 "ReleaseName",
801 &well_known_name,
802 )
803 .await?
804 .body()
805 .deserialize::<ReleaseNameReply>()
806 .map(|r| r == ReleaseNameReply::Released)
807 }
808
809 pub fn is_bus(&self) -> bool {
814 #[cfg(feature = "p2p")]
815 {
816 self.inner.bus_conn
817 }
818 #[cfg(not(feature = "p2p"))]
819 {
820 true
821 }
822 }
823
824 pub fn unique_name(&self) -> Option<&OwnedUniqueName> {
829 self.inner.unique_name.get()
830 }
831
/// Sets the unique name of this connection.
///
/// # Errors
///
/// Returns an error if `unique_name` cannot be converted into an [`OwnedUniqueName`].
///
/// # Panics
///
/// Panics if the unique name has already been set.
#[cfg(feature = "bus-impl")]
pub fn set_unique_name<U>(&self, unique_name: U) -> Result<()>
where
    U: TryInto<OwnedUniqueName>,
    U::Error: Into<Error>,
{
    match unique_name.try_into() {
        Ok(name) => {
            self.set_unique_name_(name);

            Ok(())
        }
        Err(e) => Err(e.into()),
    }
}
852
853 pub fn max_queued(&self) -> usize {
855 self.inner.msg_receiver.capacity()
856 }
857
858 pub fn set_max_queued(&mut self, max: usize) {
860 self.inner.msg_receiver.clone().set_capacity(max);
861 }
862
863 pub fn server_guid(&self) -> &OwnedGuid {
865 &self.inner.server_guid
866 }
867
868 pub fn executor(&self) -> &Executor<'static> {
920 &self.inner.executor
921 }
922
923 pub fn object_server(&self) -> &ObjectServer {
931 self.ensure_object_server(true)
932 }
933
934 pub(crate) fn ensure_object_server(&self, start: bool) -> &ObjectServer {
935 self.inner
936 .object_server
937 .get_or_init(move || self.setup_object_server(start, None))
938 }
939
/// Creates the object server for this connection.
///
/// When `start` is `true`, the dispatch task is started first, and `started_event` is handed
/// to it to be notified once the task has subscribed to method calls.
fn setup_object_server(&self, start: bool, started_event: Option<Event>) -> ObjectServer {
    if start {
        self.start_object_server(started_event);
    }

    ObjectServer::new(self)
}
947
/// Spawns the object server's dispatch task, at most once per connection.
///
/// The task does the following:
/// 1. Subscribes to method calls, restricted to this connection's unique name as
///    destination when one is set.
/// 2. Notifies `started_event` once subscribed.
/// 3. Dispatches each incoming call to [`Connection::object_server`].
///
/// Between messages it keeps only a weak reference to the connection. It stops when the
/// connection is gone, when the stream ends, or at the first stream error (which is logged).
#[instrument(skip(self))]
pub(crate) fn start_object_server(&self, started_event: Option<Event>) {
    self.inner.object_server_dispatch_task.get_or_init(|| {
        trace!("starting ObjectServer task");
        let weak_conn = WeakConnection::from(self);

        self.inner.executor.spawn(
            async move {
                let mut stream = match weak_conn.upgrade() {
                    Some(conn) => {
                        let mut builder = MatchRule::builder().msg_type(Type::MethodCall);
                        if let Some(unique_name) = conn.unique_name() {
                            builder = builder.destination(&**unique_name).expect("unique name");
                        }
                        let rule = builder.build();
                        match conn.add_match(rule.into(), None).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                debug!("Failed to create message stream: {}", e);

                                return;
                            }
                        }
                    }
                    None => {
                        trace!("Connection is gone, stopping associated object server task");

                        return;
                    }
                };
                if let Some(started_event) = started_event {
                    started_event.notify(1);
                }

                trace!("waiting for incoming method call messages..");
                // An `Err` item is logged and mapped to `None`, which ends this loop.
                while let Some(msg) = stream.next().await.and_then(|m| {
                    if let Err(e) = &m {
                        debug!("Error while reading from object server stream: {:?}", e);
                    }
                    m.ok()
                }) {
                    // The strong `conn` lives until this iteration ends, keeping the
                    // connection alive while the call is dispatched.
                    if let Some(conn) = weak_conn.upgrade() {
                        let hdr = msg.header();
                        if !conn.is_bus() {
                            // Peer-to-peer: without a bus to route by name, skip calls for a
                            // well-known name this connection has not registered (only when
                            // any names are registered at all).
                            match hdr.destination() {
                                Some(BusName::Unique(_)) | None => (),
                                Some(BusName::WellKnown(dest)) => {
                                    let names = conn.inner.registered_names.lock().await;
                                    if !names.is_empty() && !names.contains_key(dest) {
                                        trace!(
                                            "Got a method call for a different destination: {}",
                                            dest
                                        );

                                        continue;
                                    }
                                }
                            }
                        }
                        let server = conn.object_server();
                        if let Err(e) = server.dispatch_call(&msg, &hdr).await {
                            debug!(
                                "Error dispatching message. Message: {:?}, error: {:?}",
                                msg, e
                            );
                        }
                    } else {
                        trace!("Connection is gone, stopping associated object server task");
                        break;
                    }
                }
            }
            .instrument(info_span!("obj_server_task")),
            "obj_server_task",
        )
    });
}
1034
/// Subscribes to messages matching `rule` and returns a receiver for them.
///
/// Subscriptions are reference counted per rule.
///
/// The first subscription to a rule:
/// - creates a broadcast channel of capacity `max_queued` (default `DEFAULT_MAX_QUEUED`);
/// - on a bus, calls `AddMatch` for signal rules (a rule with no message type counts as a
///   signal rule);
/// - registers the channel's sender in `msg_senders` for the socket reader.
///
/// Later subscriptions reuse that channel and grow its capacity if a larger `max_queued` is
/// requested.
///
/// # Errors
///
/// Returns a `BrokenPipe` I/O error if `msg_senders` is empty, which this code treats as the
/// socket reader task having errored out. Propagates `AddMatch` failures.
pub(crate) async fn add_match(
    &self,
    rule: OwnedMatchRule,
    max_queued: Option<usize>,
) -> Result<Receiver<Result<Message>>> {
    use std::collections::hash_map::Entry;

    if self.inner.msg_senders.lock().await.is_empty() {
        return Err(Error::InputOutput(Arc::new(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "Socket reader task has errored out",
        ))));
    }

    let mut subscriptions = self.inner.subscriptions.lock().await;
    let msg_type = rule.msg_type().unwrap_or(Type::Signal);
    match subscriptions.entry(rule.clone()) {
        Entry::Vacant(e) => {
            let max_queued = max_queued.unwrap_or(DEFAULT_MAX_QUEUED);
            let (sender, mut receiver) = broadcast(max_queued);
            // Broadcasting should not wait for an active receiver when none exists.
            receiver.set_await_active(false);
            if self.is_bus() && msg_type == Type::Signal {
                self.call_method(
                    Some("org.freedesktop.DBus"),
                    "/org/freedesktop/DBus",
                    Some("org.freedesktop.DBus"),
                    "AddMatch",
                    &e.key(),
                )
                .await?;
            }
            // Keep an inactive receiver so the channel can be re-activated by later
            // subscribers; hand the active one to the caller.
            e.insert((1, receiver.clone().deactivate()));
            self.inner
                .msg_senders
                .lock()
                .await
                .insert(Some(rule), sender);

            Ok(receiver)
        }
        Entry::Occupied(mut e) => {
            let (num_subscriptions, receiver) = e.get_mut();
            *num_subscriptions += 1;
            if let Some(max_queued) = max_queued {
                if max_queued > receiver.capacity() {
                    receiver.set_capacity(max_queued);
                }
            }

            Ok(receiver.activate_cloned())
        }
    }
}
1089
/// Drops one subscription to `rule`.
///
/// Returns `Ok(false)` if there is no subscription for `rule`. Otherwise decrements its count
/// and returns `Ok(true)`. When the count reaches zero, `RemoveMatch` is called on a bus for
/// signal rules, and then the subscription and its sender are removed.
///
/// NOTE(review): if `RemoveMatch` fails, the error is returned after the count has already
/// been decremented. The zero-count entry and its sender stay in place.
pub(crate) async fn remove_match(&self, rule: OwnedMatchRule) -> Result<bool> {
    use std::collections::hash_map::Entry;
    let mut subscriptions = self.inner.subscriptions.lock().await;
    let msg_type = rule.msg_type().unwrap_or(Type::Signal);
    match subscriptions.entry(rule) {
        Entry::Vacant(_) => Ok(false),
        Entry::Occupied(mut e) => {
            // Copy of the key's inner `MatchRule`, used for the bus call and to find the
            // sender in `msg_senders` after the entry is removed.
            let rule = e.key().inner().clone();
            e.get_mut().0 -= 1;
            if e.get().0 == 0 {
                if self.is_bus() && msg_type == Type::Signal {
                    self.call_method(
                        Some("org.freedesktop.DBus"),
                        "/org/freedesktop/DBus",
                        Some("org.freedesktop.DBus"),
                        "RemoveMatch",
                        &rule,
                    )
                    .await?;
                }
                e.remove();
                self.inner
                    .msg_senders
                    .lock()
                    .await
                    .remove(&Some(rule.into()));
            }
            Ok(true)
        }
    }
}
1123
1124 pub(crate) fn queue_remove_match(&self, rule: OwnedMatchRule) {
1125 let conn = self.clone();
1126 let task_name = format!("Remove match `{}`", *rule);
1127 let remove_match =
1128 async move { conn.remove_match(rule).await }.instrument(trace_span!("{}", task_name));
1129 self.inner.executor.spawn(remove_match, &task_name).detach()
1130 }
1131
1132 pub fn method_timeout(&self) -> Option<Duration> {
1134 self.inner.method_timeout
1135 }
1136
1137 pub(crate) async fn new(
1138 auth: Authenticated,
1139 #[allow(unused)] bus_connection: bool,
1140 executor: Executor<'static>,
1141 method_timeout: Option<Duration>,
1142 ) -> Result<Self> {
1143 #[cfg(unix)]
1144 let cap_unix_fd = auth.cap_unix_fd;
1145
1146 macro_rules! create_msg_broadcast_channel {
1147 ($size:expr) => {{
1148 let (msg_sender, msg_receiver) = broadcast($size);
1149 let mut msg_receiver = msg_receiver.deactivate();
1150 msg_receiver.set_await_active(false);
1151
1152 (msg_sender, msg_receiver)
1153 }};
1154 }
1155 let (msg_sender, msg_receiver) = create_msg_broadcast_channel!(DEFAULT_MAX_QUEUED);
1157 let mut msg_senders = HashMap::new();
1158 msg_senders.insert(None, msg_sender);
1159
1160 let (method_return_sender, method_return_receiver) =
1162 create_msg_broadcast_channel!(DEFAULT_MAX_METHOD_RETURN_QUEUED);
1163 let rule = MatchRule::builder()
1164 .msg_type(Type::MethodReturn)
1165 .build()
1166 .into();
1167 msg_senders.insert(Some(rule), method_return_sender.clone());
1168 let rule = MatchRule::builder().msg_type(Type::Error).build().into();
1169 msg_senders.insert(Some(rule), method_return_sender);
1170 let msg_senders = Arc::new(Mutex::new(msg_senders));
1171 let subscriptions = Mutex::new(HashMap::new());
1172
1173 let connection = Self {
1174 inner: Arc::new(ConnectionInner {
1175 activity_event: Arc::new(Event::new()),
1176 socket_write: Mutex::new(auth.socket_write),
1177 server_guid: auth.server_guid,
1178 #[cfg(unix)]
1179 cap_unix_fd,
1180 #[cfg(feature = "p2p")]
1181 bus_conn: bus_connection,
1182 unique_name: OnceLock::new(),
1183 subscriptions,
1184 object_server: OnceLock::new(),
1185 object_server_dispatch_task: OnceLock::new(),
1186 executor,
1187 socket_reader_task: OnceLock::new(),
1188 msg_senders,
1189 msg_receiver,
1190 method_return_receiver,
1191 registered_names: Mutex::new(HashMap::new()),
1192 drop_event: Event::new(),
1193 method_timeout,
1194 }),
1195 };
1196
1197 if let Some(unique_name) = auth.unique_name {
1198 connection.set_unique_name_(unique_name);
1199 }
1200
1201 Ok(connection)
1202 }
1203
1204 pub async fn session() -> Result<Self> {
1206 Builder::session()?.build().await
1207 }
1208
1209 pub async fn system() -> Result<Self> {
1211 Builder::system()?.build().await
1212 }
1213
1214 pub fn monitor_activity(&self) -> EventListener {
1218 self.inner.activity_event.listen()
1219 }
1220
1221 pub async fn peer_credentials(&self) -> io::Result<ConnectionCredentials> {
1230 self.inner
1231 .socket_write
1232 .lock()
1233 .await
1234 .peer_credentials()
1235 .await
1236 }
1237
1238 pub async fn close(self) -> Result<()> {
1242 self.inner.activity_event.notify(usize::MAX);
1243 self.inner
1244 .socket_write
1245 .lock()
1246 .await
1247 .close()
1248 .await
1249 .map_err(Into::into)
1250 }
1251
1252 pub async fn graceful_shutdown(self) {
1296 let listener = self.inner.drop_event.listen();
1297 drop(self);
1298 listener.await;
1299 }
1300
1301 pub(crate) fn init_socket_reader(
1302 &self,
1303 socket_read: Box<dyn socket::ReadHalf>,
1304 already_read: Vec<u8>,
1305 #[cfg(unix)] already_received_fds: Vec<std::os::fd::OwnedFd>,
1306 ) {
1307 let inner = &self.inner;
1308 inner
1309 .socket_reader_task
1310 .set(
1311 SocketReader::new(
1312 socket_read,
1313 inner.msg_senders.clone(),
1314 already_read,
1315 #[cfg(unix)]
1316 already_received_fds,
1317 inner.activity_event.clone(),
1318 )
1319 .spawn(&inner.executor),
1320 )
1321 .expect("Attempted to set `socket_reader_task` twice");
1322 }
1323
1324 fn set_unique_name_(&self, name: OwnedUniqueName) {
1325 self.inner
1326 .unique_name
1327 .set(name)
1328 .expect("unique name already set");
1330 }
1331}
1332
#[cfg(feature = "blocking-api")]
impl From<crate::blocking::Connection> for Connection {
    /// Unwraps the asynchronous connection that a blocking connection wraps.
    fn from(blocking: crate::blocking::Connection) -> Self {
        crate::blocking::Connection::into_inner(blocking)
    }
}
1339
/// A non-owning handle to a [`Connection`].
///
/// Background tasks hold this so they do not keep the connection alive.
#[derive(Debug, Clone)]
pub(crate) struct WeakConnection {
    inner: Weak<ConnectionInner>,
}
1345
1346impl WeakConnection {
1347 pub fn upgrade(&self) -> Option<Connection> {
1349 self.inner.upgrade().map(|inner| Connection { inner })
1350 }
1351}
1352
1353impl From<&Connection> for WeakConnection {
1354 fn from(conn: &Connection) -> Self {
1355 Self {
1356 inner: Arc::downgrade(&conn.inner),
1357 }
1358 }
1359}
1360
/// Local status of a well-known name requested through this connection.
///
/// The tasks are only stored (hence `allow(unused)`), so they live as long as the entry.
#[derive(Debug)]
enum NameStatus {
    /// The name is owned. Holds the `NameLost` watcher when the request allowed replacement.
    Owner(#[allow(unused)] Option<Task<()>>),
    /// The request is queued. Holds the task waiting for `NameAcquired`.
    Queued(#[allow(unused)] Task<()>),
}
1368
/// Process-wide permit that serializes message building and sending under Flatpak (see
/// `acquire_serial_num_semaphore`).
static SERIAL_NUM_SEMAPHORE: Semaphore = Semaphore::new(1);
1370
1371async fn acquire_serial_num_semaphore() -> Option<SemaphorePermit<'static>> {
1376 if is_flatpak() {
1377 Some(SERIAL_NUM_SEMAPHORE.acquire().await)
1378 } else {
1379 None
1380 }
1381}
1382
1383#[cfg(test)]
1384mod tests {
1385 use super::*;
1386 use crate::fdo::DBusProxy;
1387 use ntest::timeout;
1388 use std::{pin::pin, time::Duration};
1389 use test_log::test;
1390
// Windows only: resolves the autolaunch session bus address and connects to it.
#[cfg(windows)]
#[test]
fn connect_autolaunch_session_bus() {
    let addr =
        crate::win32::autolaunch_bus_address().expect("Unable to get session bus address");

    crate::block_on(async { addr.connect().await }).expect("Unable to connect to session bus");
}
1399
// macOS only: connects to the session bus through the launchd transport, keyed by the
// `DBUS_LAUNCHD_SESSION_BUS_SOCKET` variable.
#[cfg(target_os = "macos")]
#[test]
fn connect_launchd_session_bus() {
    use crate::address::{transport::Launchd, Address, Transport};
    crate::block_on(async {
        let addr = Address::from(Transport::Launchd(Launchd::new(
            "DBUS_LAUNCHD_SESSION_BUS_SOCKET",
        )));
        addr.connect().await
    })
    .expect("Unable to connect to session bus");
}
1412
// Drives `test_disconnect_on_drop` to completion on the crate's `block_on` helper.
#[test]
#[timeout(15000)]
fn disconnect_on_drop() {
    let test = test_disconnect_on_drop();
    crate::utils::block_on(test);
}
1420
// Checks that dropping a connection that owns a well-known name makes the bus drop that name.
async fn test_disconnect_on_drop() {
    #[derive(Default)]
    struct MyInterface {}

    #[crate::interface(name = "dev.peelz.FooBar.Baz")]
    impl MyInterface {
        fn do_thing(&self) {}
    }
    let name = "dev.peelz.foobar";
    let connection = Builder::session()
        .unwrap()
        .name(name)
        .unwrap()
        .serve_at("/dev/peelz/FooBar", MyInterface::default())
        .unwrap()
        .build()
        .await
        .unwrap();

    let connection2 = Connection::session().await.unwrap();
    let dbus = DBusProxy::new(&connection2).await.unwrap();
    // `NameOwnerChanged` for `name` (arg 0) with an empty new owner (arg 2), i.e. the name
    // lost its owner.
    let mut stream = dbus
        .receive_name_owner_changed_with_args(&[(0, name), (2, "")])
        .await
        .unwrap();

    drop(connection);

    // Wait for the bus to announce the name has no owner any more.
    stream.next().await.unwrap();

    let name_has_owner = dbus.name_has_owner(name.try_into().unwrap()).await.unwrap();
    assert!(!name_has_owner);
}
1456
// Checks that `graceful_shutdown` waits for two things:
// 1. every other clone of the connection to be dropped;
// 2. an in-flight method call dispatched by the object server to finish.
// The clock is paused, so tokio auto-advances time when idle and the huge `sleep` arms only
// fire once nothing else can make progress.
#[tokio::test(start_paused = true)]
#[timeout(15000)]
async fn test_graceful_shutdown() {
    // Part 1: a live clone must block shutdown.
    let connection = Connection::session().await.unwrap();
    let clone = connection.clone();
    let mut shutdown = pin!(connection.graceful_shutdown());
    tokio::select! {
        _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
        _ = &mut shutdown => {
            panic!("Graceful shutdown unexpectedly completed");
        }
    }

    // Dropping the last other clone lets shutdown finish.
    drop(clone);
    shutdown.await;

    // Part 2: an interface whose method blocks until `trigger_return` fires.
    struct GracefulInterface {
        method_called: Event,
        wait_before_return: Option<EventListener>,
        announce_done: Event,
    }

    #[crate::interface(name = "dev.peelz.TestGracefulShutdown")]
    impl GracefulInterface {
        async fn do_thing(&mut self) {
            self.method_called.notify(1);
            if let Some(listener) = self.wait_before_return.take() {
                listener.await;
            }
            self.announce_done.notify(1);
        }
    }

    let method_called = Event::new();
    let method_called_listener = method_called.listen();

    let trigger_return = Event::new();
    let wait_before_return = Some(trigger_return.listen());

    let announce_done = Event::new();
    let done_listener = announce_done.listen();

    let interface = GracefulInterface {
        method_called,
        wait_before_return,
        announce_done,
    };

    let name = "dev.peelz.TestGracefulShutdown";
    let obj = "/dev/peelz/TestGracefulShutdown";
    let connection = Builder::session()
        .unwrap()
        .name(name)
        .unwrap()
        .serve_at(obj, interface)
        .unwrap()
        .build()
        .await
        .unwrap();

    let client_conn = Connection::session().await.unwrap();
    tokio::spawn(async move {
        client_conn
            .call_method(Some(name), obj, Some(name), "DoThing", &())
            .await
            .unwrap();
    });

    // Wait until the server is inside `do_thing`, i.e. a dispatch is in flight.
    method_called_listener.await;

    // The in-flight dispatch holds a strong reference, so shutdown must not complete yet.
    let mut shutdown = pin!(connection.graceful_shutdown());
    tokio::select! {
        _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
        _ = &mut shutdown => {
            panic!("Graceful shutdown unexpectedly completed");
        }
    }

    // Let the method return; the dispatch ends and shutdown can complete.
    trigger_return.notify(1);
    shutdown.await;

    done_listener.await;
}
1550}
1551
1552#[cfg(feature = "p2p")]
1553#[cfg(test)]
1554mod p2p_tests {
1555 use event_listener::Event;
1556 use futures_util::TryStreamExt;
1557 use ntest::timeout;
1558 use test_log::test;
1559 use zvariant::{Endian, NATIVE_ENDIAN};
1560
1561 use super::{socket, Builder, Connection};
1562 use crate::{conn::AuthMechanism, Guid, Message, MessageStream, Result};
1563
// Joins two p2p connection pairs with a relay:
// - messages received by `server1` are sent on `client2`;
// - messages received by `client2` are sent on `server1`.
// So `client1` and `server2` talk through it.
//
// The exchange:
// 1. `client1` sends a method call built with the non-native endianness.
// 2. `server2` emits a signal and then replies.
// 3. The client checks it sees the signal, then a reply with the call's endianness and body
//    "yay".
async fn test_p2p(
    server1: Connection,
    client1: Connection,
    server2: Connection,
    client2: Connection,
) -> Result<()> {
    let forward1 = {
        let stream = MessageStream::from(server1.clone());
        let sink = client2.clone();

        stream.try_for_each(move |msg| {
            let sink = sink.clone();
            async move { sink.send(&msg).await }
        })
    };
    let forward2 = {
        let stream = MessageStream::from(client2.clone());
        let sink = server1.clone();

        stream.try_for_each(move |msg| {
            let sink = sink.clone();
            async move { sink.send(&msg).await }
        })
    };
    // Bound to a named variable (not `_`) so the relay task is kept for the whole test.
    let _forward_task = client1.executor().spawn(
        async move { futures_util::try_join!(forward1, forward2) },
        "forward_task",
    );

    let server_ready = Event::new();
    let server_ready_listener = server_ready.listen();
    let client_done = Event::new();
    let client_done_listener = client_done.listen();

    let server_future = async move {
        // Create the stream before signalling readiness so the call cannot be missed.
        let mut stream = MessageStream::from(&server2);
        server_ready.notify(1);
        let method = loop {
            let m = stream.try_next().await?.unwrap();
            if m.to_string() == "Method call Test" {
                assert_eq!(m.body().deserialize::<u64>().unwrap(), 64);
                break m;
            }
        };

        server2
            .emit_signal(None::<()>, "/", "org.zbus.p2p", "ASignalForYou", &())
            .await?;
        server2.reply(&method.header(), &("yay")).await?;
        client_done_listener.await;

        Ok(())
    };

    let client_future = async move {
        let mut stream = MessageStream::from(&client1);
        server_ready_listener.await;
        // Deliberately use the non-native byte order for the call.
        let endian = match NATIVE_ENDIAN {
            Endian::Little => Endian::Big,
            Endian::Big => Endian::Little,
        };
        let method = Message::method_call("/", "Test")?
            .interface("org.zbus.p2p")?
            .endian(endian)
            .build(&64u64)?;
        client1.send(&method).await?;
        let m = stream.try_next().await?.unwrap();
        client_done.notify(1);
        assert_eq!(m.to_string(), "Signal ASignalForYou");
        let reply = stream.try_next().await?.unwrap();
        assert_eq!(reply.to_string(), "Method return");
        assert_eq!(Endian::from(reply.primary_header().endian_sig()), endian);
        reply.body().deserialize::<String>()
    };

    let (val, _) = futures_util::try_join!(client_future, server_future,)?;
    assert_eq!(val, "yay");

    Ok(())
}
1651
1652 #[test]
1653 #[timeout(15000)]
1654 fn tcp_p2p() {
1655 crate::utils::block_on(test_tcp_p2p()).unwrap();
1656 }
1657
1658 async fn test_tcp_p2p() -> Result<()> {
1659 let (server1, client1) = tcp_p2p_pipe().await?;
1660 let (server2, client2) = tcp_p2p_pipe().await?;
1661
1662 test_p2p(server1, client1, server2, client2).await
1663 }
1664
1665 async fn tcp_p2p_pipe() -> Result<(Connection, Connection)> {
1666 let guid = Guid::generate();
1667
1668 #[cfg(not(feature = "tokio"))]
1669 let (server_conn_builder, client_conn_builder) = {
1670 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
1671 let addr = listener.local_addr().unwrap();
1672 let p1 = std::net::TcpStream::connect(addr).unwrap();
1673 let p0 = listener.incoming().next().unwrap().unwrap();
1674
1675 (
1676 Builder::tcp_stream(p0)
1677 .server(guid)
1678 .unwrap()
1679 .p2p()
1680 .auth_mechanism(AuthMechanism::Anonymous),
1681 Builder::tcp_stream(p1).p2p(),
1682 )
1683 };
1684
1685 #[cfg(feature = "tokio")]
1686 let (server_conn_builder, client_conn_builder) = {
1687 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1688 let addr = listener.local_addr().unwrap();
1689 let p1 = tokio::net::TcpStream::connect(addr).await.unwrap();
1690 let p0 = listener.accept().await.unwrap().0;
1691
1692 (
1693 Builder::tcp_stream(p0)
1694 .server(guid)
1695 .unwrap()
1696 .p2p()
1697 .auth_mechanism(AuthMechanism::Anonymous),
1698 Builder::tcp_stream(p1).p2p(),
1699 )
1700 };
1701
1702 futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
1703 }
1704
1705 #[cfg(unix)]
1706 #[test]
1707 #[timeout(15000)]
1708 fn unix_p2p() {
1709 crate::utils::block_on(test_unix_p2p()).unwrap();
1710 }
1711
1712 #[cfg(unix)]
1713 async fn test_unix_p2p() -> Result<()> {
1714 let (server1, client1) = unix_p2p_pipe().await?;
1715 let (server2, client2) = unix_p2p_pipe().await?;
1716
1717 test_p2p(server1, client1, server2, client2).await
1718 }
1719
1720 #[cfg(unix)]
1721 async fn unix_p2p_pipe() -> Result<(Connection, Connection)> {
1722 #[cfg(not(feature = "tokio"))]
1723 use std::os::unix::net::UnixStream;
1724 #[cfg(feature = "tokio")]
1725 use tokio::net::UnixStream;
1726 #[cfg(all(windows, not(feature = "tokio")))]
1727 use uds_windows::UnixStream;
1728
1729 let guid = Guid::generate();
1730
1731 let (p0, p1) = UnixStream::pair().unwrap();
1732
1733 futures_util::try_join!(
1734 Builder::unix_stream(p1).p2p().build(),
1735 Builder::unix_stream(p0).server(guid).unwrap().p2p().build(),
1736 )
1737 }
1738
1739 #[cfg(any(
1740 all(feature = "vsock", not(feature = "tokio")),
1741 feature = "tokio-vsock"
1742 ))]
1743 #[test]
1744 #[timeout(15000)]
1745 fn vsock_connect() {
1746 let _ = crate::utils::block_on(test_vsock_connect()).unwrap();
1747 }
1748
1749 #[cfg(any(
1750 all(feature = "vsock", not(feature = "tokio")),
1751 feature = "tokio-vsock"
1752 ))]
1753 async fn test_vsock_connect() -> Result<(Connection, Connection)> {
1754 #[cfg(feature = "tokio-vsock")]
1755 use futures_util::StreamExt;
1756
1757 let guid = Guid::generate();
1758
1759 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1760 let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX)?;
1761 #[cfg(feature = "tokio-vsock")]
1762 let listener = tokio_vsock::VsockListener::bind(tokio_vsock::VsockAddr::new(1, u32::MAX))?;
1763
1764 let addr = listener.local_addr()?;
1765 let addr = format!("vsock:cid={},port={},guid={guid}", addr.cid(), addr.port());
1766
1767 let server = async {
1768 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1769 let server = crate::Task::spawn_blocking(move || listener.incoming().next(), "").await;
1770 #[cfg(feature = "tokio-vsock")]
1771 let server = listener.incoming().next().await;
1772 Builder::vsock_stream(server.unwrap()?)
1773 .server(guid)?
1774 .p2p()
1775 .auth_mechanism(AuthMechanism::Anonymous)
1776 .build()
1777 .await
1778 };
1779
1780 let client = crate::connection::Builder::address(addr.as_str())?
1781 .p2p()
1782 .build();
1783
1784 futures_util::try_join!(server, client)
1785 }
1786
1787 #[cfg(any(
1788 all(feature = "vsock", not(feature = "tokio")),
1789 feature = "tokio-vsock"
1790 ))]
1791 #[test]
1792 #[timeout(15000)]
1793 fn vsock_p2p() {
1794 crate::utils::block_on(test_vsock_p2p()).unwrap();
1795 }
1796
1797 #[cfg(any(
1798 all(feature = "vsock", not(feature = "tokio")),
1799 feature = "tokio-vsock"
1800 ))]
1801 async fn test_vsock_p2p() -> Result<()> {
1802 let (server1, client1) = vsock_p2p_pipe().await?;
1803 let (server2, client2) = vsock_p2p_pipe().await?;
1804
1805 test_p2p(server1, client1, server2, client2).await
1806 }
1807
1808 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1809 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1810 let guid = Guid::generate();
1811
1812 let listener =
1813 vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX).unwrap();
1814 let addr = listener.local_addr().unwrap();
1815 let client = vsock::VsockStream::connect(&addr).unwrap();
1816 let server = listener.incoming().next().unwrap().unwrap();
1817
1818 futures_util::try_join!(
1819 Builder::vsock_stream(server)
1820 .server(guid)
1821 .unwrap()
1822 .p2p()
1823 .auth_mechanism(AuthMechanism::Anonymous)
1824 .build(),
1825 Builder::vsock_stream(client).p2p().build(),
1826 )
1827 }
1828
1829 #[cfg(feature = "tokio-vsock")]
1830 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1831 use futures_util::StreamExt;
1832 use tokio_vsock::VsockAddr;
1833
1834 let guid = Guid::generate();
1835
1836 let listener = tokio_vsock::VsockListener::bind(VsockAddr::new(1, u32::MAX)).unwrap();
1837 let addr = listener.local_addr().unwrap();
1838 let client = tokio_vsock::VsockStream::connect(addr).await.unwrap();
1839 let server = listener.incoming().next().await.unwrap().unwrap();
1840
1841 futures_util::try_join!(
1842 Builder::vsock_stream(server)
1843 .server(guid)
1844 .unwrap()
1845 .p2p()
1846 .auth_mechanism(AuthMechanism::Anonymous)
1847 .build(),
1848 Builder::vsock_stream(client).p2p().build(),
1849 )
1850 }
1851
1852 #[test]
1853 #[timeout(15000)]
1854 fn channel_pair() {
1855 crate::utils::block_on(test_channel_pair()).unwrap();
1856 }
1857
1858 async fn test_channel_pair() -> Result<()> {
1859 let (server1, client1) = create_channel_pair().await;
1860 let (server2, client2) = create_channel_pair().await;
1861
1862 test_p2p(server1, client1, server2, client2).await
1863 }
1864
1865 async fn create_channel_pair() -> (Connection, Connection) {
1866 let (a, b) = socket::Channel::pair();
1867
1868 let guid = crate::Guid::generate();
1869 let conn1 = Builder::authenticated_socket(a, guid.clone())
1870 .unwrap()
1871 .p2p()
1872 .build()
1873 .await
1874 .unwrap();
1875 let conn2 = Builder::authenticated_socket(b, guid)
1876 .unwrap()
1877 .p2p()
1878 .build()
1879 .await
1880 .unwrap();
1881
1882 (conn1, conn2)
1883 }
1884}