1use async_broadcast::{broadcast, InactiveReceiver, Receiver, Sender as Broadcaster};
3use enumflags2::BitFlags;
4use event_listener::{Event, EventListener};
5use ordered_stream::{OrderedFuture, OrderedStream, PollResult};
6use std::{
7 collections::HashMap,
8 io::{self, ErrorKind},
9 num::NonZeroU32,
10 pin::Pin,
11 sync::{Arc, OnceLock, Weak},
12 task::{Context, Poll},
13 time::Duration,
14};
15use tracing::{debug, info_span, instrument, trace, trace_span, warn, Instrument};
16use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
17use zvariant::ObjectPath;
18
19use futures_core::Future;
20use futures_lite::StreamExt;
21
22use crate::{
23 async_lock::{Mutex, Semaphore, SemaphorePermit},
24 fdo::{ConnectionCredentials, ReleaseNameReply, RequestNameFlags, RequestNameReply},
25 is_flatpak,
26 message::{Flags, Message, Type},
27 timeout::timeout,
28 DBusError, Error, Executor, MatchRule, MessageStream, ObjectServer, OwnedGuid, OwnedMatchRule,
29 Result, Task,
30};
31
32mod builder;
33pub use builder::Builder;
34
35pub mod socket;
36pub use socket::Socket;
37
38mod socket_reader;
39use socket_reader::SocketReader;
40
41pub(crate) mod handshake;
42pub use handshake::AuthMechanism;
43use handshake::Authenticated;
44
/// Default capacity of the method-return/error channel shared by all pending method calls.
const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8;
/// Default capacity of broadcast channels that feed message streams and match-rule subscriptions.
const DEFAULT_MAX_QUEUED: usize = 64;
47
/// Shared state behind every clone of a `Connection`.
///
/// NOTE(review): struct fields are dropped in declaration order, so reordering the fields
/// below changes teardown order; review before rearranging.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// GUID reported by the server during authentication.
    server_guid: OwnedGuid,
    /// Whether unix file-descriptor passing was negotiated; `send` rejects messages that carry
    /// FDs when this is `false`.
    #[cfg(unix)]
    cap_unix_fd: bool,
    /// `true` for a bus connection, `false` for a peer-to-peer one (read by `is_bus`).
    #[cfg(feature = "p2p")]
    bus_conn: bool,
    /// This connection's unique name; set at most once.
    unique_name: OnceLock<OwnedUniqueName>,
    /// Well-known names requested through this connection, with their local ownership status.
    registered_names: Mutex<HashMap<WellKnownName<'static>, NameStatus>>,

    /// Notified when a message is sent or the connection is closed; also handed to the socket
    /// reader.
    activity_event: Arc<Event>,
    /// Write half of the socket; the async mutex serialises concurrent writers.
    socket_write: Mutex<Box<dyn socket::WriteHalf>>,

    /// Executor on which the connection spawns its internal tasks.
    executor: Executor<'static>,

    /// Never read; held so the socket reader task is owned by (and dropped with) the
    /// connection.
    #[allow(unused)]
    socket_reader_task: OnceLock<Task<()>>,

    /// Inactive receiver of the catch-all channel (the `None` entry in `msg_senders`).
    pub(crate) msg_receiver: InactiveReceiver<Result<Message>>,
    /// Inactive receiver of the method-return/error channel; activated once per pending call.
    pub(crate) method_return_receiver: InactiveReceiver<Result<Message>>,
    /// Broadcast senders keyed by match rule (`None` receives every message); shared with the
    /// socket reader.
    msg_senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,

    /// Active match-rule subscriptions with their subscriber counts.
    subscriptions: Mutex<Subscriptions>,

    /// Lazily created object server.
    object_server: OnceLock<ObjectServer>,
    /// Task dispatching incoming method calls to the object server; started at most once.
    object_server_dispatch_task: OnceLock<Task<()>>,

    /// Notified (for all listeners) when this value is dropped; awaited by
    /// `Connection::graceful_shutdown`.
    drop_event: Event,

    /// Timeout applied by `Connection::call_method`, if any.
    method_timeout: Option<Duration>,
}
82
83impl Drop for ConnectionInner {
84 fn drop(&mut self) {
85 self.drop_event.notify(usize::MAX);
89 }
90}
91
/// Active match-rule subscriptions: rule -> (subscriber count, inactive receiver of the rule's
/// broadcast channel).
type Subscriptions = HashMap<OwnedMatchRule, (u64, InactiveReceiver<Result<Message>>)>;

/// Sending half of a message broadcast channel.
pub(crate) type MsgBroadcaster = Broadcaster<Result<Message>>;
95
/// A D-Bus connection.
///
/// Cloning is cheap: every clone shares the same reference-counted `ConnectionInner`.
#[derive(Clone, Debug)]
#[must_use = "Dropping a `Connection` will close the underlying socket."]
pub struct Connection {
    pub(crate) inner: Arc<ConnectionInner>,
}
215
/// A method call that has been sent and whose reply has not been consumed yet.
///
/// Resolves, as a `Future` or an `OrderedFuture`, to the first method return or error
/// message whose reply serial equals `serial`.
#[derive(Debug)]
pub(crate) struct PendingMethodCall {
    /// Stream of method returns/errors; set to `None` once a reply has been received.
    stream: Option<MessageStream>,
    /// Serial number of the outgoing method call.
    serial: NonZeroU32,
}
226
227impl Future for PendingMethodCall {
228 type Output = Result<Message>;
229
230 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
231 self.poll_before(cx, None).map(|ret| {
232 ret.map(|(_, r)| r).unwrap_or_else(|| {
233 Err(crate::Error::InputOutput(
234 io::Error::new(ErrorKind::BrokenPipe, "socket closed").into(),
235 ))
236 })
237 })
238 }
239}
240
241impl OrderedFuture for PendingMethodCall {
242 type Output = Result<Message>;
243 type Ordering = zbus::message::Sequence;
244
245 fn poll_before(
246 self: Pin<&mut Self>,
247 cx: &mut Context<'_>,
248 before: Option<&Self::Ordering>,
249 ) -> Poll<Option<(Self::Ordering, Self::Output)>> {
250 let this = self.get_mut();
251 if let Some(stream) = &mut this.stream {
252 loop {
253 match Pin::new(&mut *stream).poll_next_before(cx, before) {
254 Poll::Ready(PollResult::Item {
255 data: Ok(msg),
256 ordering,
257 }) => {
258 if msg.header().reply_serial() != Some(this.serial) {
259 continue;
260 }
261 let res = match msg.message_type() {
262 Type::Error => Err(msg.into()),
263 Type::MethodReturn => Ok(msg),
264 _ => continue,
265 };
266 this.stream = None;
267 return Poll::Ready(Some((ordering, res)));
268 }
269 Poll::Ready(PollResult::Item {
270 data: Err(e),
271 ordering,
272 }) => {
273 return Poll::Ready(Some((ordering, Err(e))));
274 }
275
276 Poll::Ready(PollResult::NoneBefore) => {
277 return Poll::Ready(None);
278 }
279 Poll::Ready(PollResult::Terminated) => {
280 return Poll::Ready(None);
281 }
282 Poll::Pending => return Poll::Pending,
283 }
284 }
285 }
286 Poll::Ready(None)
287 }
288}
289
impl Connection {
    /// Sends `msg` over the connection's socket.
    ///
    /// Notifies activity listeners before writing.
    ///
    /// # Errors
    ///
    /// On unix, returns [`Error::Unsupported`] if the message carries file descriptors but FD
    /// passing was not negotiated. Otherwise, any error from the socket write is propagated.
    pub async fn send(&self, msg: &Message) -> Result<()> {
        #[cfg(unix)]
        {
            let carries_fds = !msg.data().fds().is_empty();
            if carries_fds && !self.inner.cap_unix_fd {
                return Err(Error::Unsupported);
            }
        }

        self.inner.activity_event.notify(usize::MAX);
        self.inner.socket_write.lock().await.send_message(msg).await
    }
303
304 pub async fn call_method<'d, 'p, 'i, 'm, D, P, I, M, B>(
311 &self,
312 destination: Option<D>,
313 path: P,
314 interface: Option<I>,
315 method_name: M,
316 body: &B,
317 ) -> Result<Message>
318 where
319 D: TryInto<BusName<'d>>,
320 P: TryInto<ObjectPath<'p>>,
321 I: TryInto<InterfaceName<'i>>,
322 M: TryInto<MemberName<'m>>,
323 D::Error: Into<Error>,
324 P::Error: Into<Error>,
325 I::Error: Into<Error>,
326 M::Error: Into<Error>,
327 B: serde::ser::Serialize + zvariant::DynamicType,
328 {
329 let method = self
330 .call_method_raw(
331 destination,
332 path,
333 interface,
334 method_name,
335 BitFlags::empty(),
336 body,
337 )
338 .await?
339 .expect("no reply");
340
341 if let Some(tout) = self.method_timeout() {
342 timeout(method, tout).await
343 } else {
344 method.await
345 }
346 }
347
348 pub(crate) async fn call_method_raw<'d, 'p, 'i, 'm, D, P, I, M, B>(
359 &self,
360 destination: Option<D>,
361 path: P,
362 interface: Option<I>,
363 method_name: M,
364 flags: BitFlags<Flags>,
365 body: &B,
366 ) -> Result<Option<PendingMethodCall>>
367 where
368 D: TryInto<BusName<'d>>,
369 P: TryInto<ObjectPath<'p>>,
370 I: TryInto<InterfaceName<'i>>,
371 M: TryInto<MemberName<'m>>,
372 D::Error: Into<Error>,
373 P::Error: Into<Error>,
374 I::Error: Into<Error>,
375 M::Error: Into<Error>,
376 B: serde::ser::Serialize + zvariant::DynamicType,
377 {
378 let _permit = acquire_serial_num_semaphore().await;
379
380 let mut builder = Message::method_call(path, method_name)?;
381 if let Some(sender) = self.unique_name() {
382 builder = builder.sender(sender)?
383 }
384 if let Some(destination) = destination {
385 builder = builder.destination(destination)?
386 }
387 if let Some(interface) = interface {
388 builder = builder.interface(interface)?
389 }
390 for flag in flags {
391 builder = builder.with_flags(flag)?;
392 }
393 let msg = builder.build(body)?;
394
395 let msg_receiver = self.inner.method_return_receiver.activate_cloned();
396 let stream = Some(MessageStream::for_subscription_channel(
397 msg_receiver,
398 None,
400 self,
401 ));
402 let serial = msg.primary_header().serial_num();
403 self.send(&msg).await?;
404 if flags.contains(Flags::NoReplyExpected) {
405 Ok(None)
406 } else {
407 Ok(Some(PendingMethodCall { stream, serial }))
408 }
409 }
410
411 pub async fn emit_signal<'d, 'p, 'i, 'm, D, P, I, M, B>(
415 &self,
416 destination: Option<D>,
417 path: P,
418 interface: I,
419 signal_name: M,
420 body: &B,
421 ) -> Result<()>
422 where
423 D: TryInto<BusName<'d>>,
424 P: TryInto<ObjectPath<'p>>,
425 I: TryInto<InterfaceName<'i>>,
426 M: TryInto<MemberName<'m>>,
427 D::Error: Into<Error>,
428 P::Error: Into<Error>,
429 I::Error: Into<Error>,
430 M::Error: Into<Error>,
431 B: serde::ser::Serialize + zvariant::DynamicType,
432 {
433 let _permit = acquire_serial_num_semaphore().await;
434
435 let mut b = Message::signal(path, interface, signal_name)?;
436 if let Some(sender) = self.unique_name() {
437 b = b.sender(sender)?;
438 }
439 if let Some(destination) = destination {
440 b = b.destination(destination)?;
441 }
442 let m = b.build(body)?;
443
444 self.send(&m).await
445 }
446
447 pub async fn reply<B>(&self, call: &zbus::message::Header<'_>, body: &B) -> Result<()>
452 where
453 B: serde::ser::Serialize + zvariant::DynamicType,
454 {
455 let _permit = acquire_serial_num_semaphore().await;
456
457 let mut b = Message::method_return(call)?;
458 if let Some(sender) = self.unique_name() {
459 b = b.sender(sender)?;
460 }
461 let m = b.build(body)?;
462 self.send(&m).await
463 }
464
465 pub async fn reply_error<'e, E, B>(
470 &self,
471 call: &zbus::message::Header<'_>,
472 error_name: E,
473 body: &B,
474 ) -> Result<()>
475 where
476 B: serde::ser::Serialize + zvariant::DynamicType,
477 E: TryInto<ErrorName<'e>>,
478 E::Error: Into<Error>,
479 {
480 let _permit = acquire_serial_num_semaphore().await;
481
482 let mut b = Message::error(call, error_name)?;
483 if let Some(sender) = self.unique_name() {
484 b = b.sender(sender)?;
485 }
486 let m = b.build(body)?;
487 self.send(&m).await
488 }
489
490 pub async fn reply_dbus_error(
495 &self,
496 call: &zbus::message::Header<'_>,
497 err: impl DBusError,
498 ) -> Result<()> {
499 let _permit = acquire_serial_num_semaphore().await;
500
501 let m = err.create_reply(call)?;
502 self.send(&m).await
503 }
504
505 pub async fn request_name<'w, W>(&self, well_known_name: W) -> Result<()>
535 where
536 W: TryInto<WellKnownName<'w>>,
537 W::Error: Into<Error>,
538 {
539 self.request_name_with_flags(well_known_name, BitFlags::default())
540 .await
541 .map(|_| ())
542 }
543
    /// Requests ownership of `well_known_name`, passing `flags` to the bus' `RequestName`.
    ///
    /// If this connection already tracks the name, the locally cached status is returned
    /// without contacting the bus. On a non-bus (peer-to-peer) connection the name is simply
    /// recorded as owned.
    ///
    /// On a bus connection:
    /// - If the reply is `InQueue`, a task is spawned that waits for `NameAcquired` and then
    ///   marks the name as owned.
    /// - If `flags` contains `AllowReplacement`, a task also watches for `NameLost` and
    ///   forgets the name once it is lost.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NameTaken`] if the bus replies `Exists`. Otherwise it propagates name
    /// conversion, match registration, method-call and reply-decoding errors.
    pub async fn request_name_with_flags<'w, W>(
        &self,
        well_known_name: W,
        flags: BitFlags<RequestNameFlags>,
    ) -> Result<RequestNameReply>
    where
        W: TryInto<WellKnownName<'w>>,
        W::Error: Into<Error>,
    {
        let well_known_name = well_known_name.try_into().map_err(Into::into)?;
        // Held until the end of the function. The tasks spawned below also lock this map, so
        // they cannot look the name up before its entry is inserted at the bottom.
        let mut names = self.inner.registered_names.lock().await;

        // Already tracked locally: answer from the cached status.
        match names.get(&well_known_name) {
            Some(NameStatus::Owner(_)) => return Ok(RequestNameReply::AlreadyOwner),
            Some(NameStatus::Queued(_)) => return Ok(RequestNameReply::InQueue),
            None => (),
        }

        // No bus to arbitrate ownership: record the name as owned right away.
        if !self.is_bus() {
            names.insert(well_known_name.to_owned(), NameStatus::Owner(None));

            return Ok(RequestNameReply::PrimaryOwner);
        }

        // Both signal subscriptions are set up before `RequestName` is sent. Presumably this is
        // so that a signal emitted right after the reply cannot be missed.
        let acquired_match_rule = MatchRule::fdo_signal_builder("NameAcquired")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut acquired_stream = self.add_match(acquired_match_rule.into(), None).await?;
        let lost_match_rule = MatchRule::fdo_signal_builder("NameLost")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut lost_stream = self.add_match(lost_match_rule.into(), None).await?;
        let reply = self
            .call_method(
                Some("org.freedesktop.DBus"),
                "/org/freedesktop/DBus",
                Some("org.freedesktop.DBus"),
                "RequestName",
                &(well_known_name.clone(), flags),
            )
            .await?
            .body()
            .deserialize::<RequestNameReply>()?;
        let lost_task_name = format!("monitor name {well_known_name} lost");
        // With `AllowReplacement` the name may be taken from us later. Prepare (but do not spawn
        // yet) a future that removes the local entry when `NameLost` arrives.
        let name_lost_fut = if flags.contains(RequestNameFlags::AllowReplacement) {
            let weak_conn = WeakConnection::from(self);
            let well_known_name = well_known_name.to_owned();
            Some(
                async move {
                    loop {
                        let signal = lost_stream.next().await;
                        // Stop once the connection itself is gone.
                        let inner = match weak_conn.upgrade() {
                            Some(conn) => conn.inner.clone(),
                            None => break,
                        };

                        match signal {
                            Some(signal) => match signal {
                                Ok(_) => {
                                    tracing::info!(
                                        "Connection `{}` lost name `{}`",
                                        // NOTE(review): the unwrap assumes a bus connection
                                        // always has its unique name set by now; confirm.
                                        inner.unique_name.get().unwrap(),
                                        well_known_name
                                    );
                                    inner.registered_names.lock().await.remove(&well_known_name);

                                    break;
                                }
                                Err(e) => warn!("Failed to parse `NameLost` signal: {}", e),
                            },
                            None => {
                                trace!("`NameLost` signal stream closed");
                                break;
                            }
                        }
                    }
                }
                .instrument(info_span!("{}", lost_task_name)),
            )
        } else {
            None
        };
        let status = match reply {
            RequestNameReply::InQueue => {
                // Queued: wait for `NameAcquired`, then promote the entry to `Owner`, spawning
                // the `NameLost` monitor at that point if one was prepared.
                let weak_conn = WeakConnection::from(self);
                let well_known_name = well_known_name.to_owned();
                let task_name = format!("monitor name {well_known_name} acquired");
                let task = self.executor().spawn(
                    async move {
                        loop {
                            let signal = acquired_stream.next().await;
                            let inner = match weak_conn.upgrade() {
                                Some(conn) => conn.inner.clone(),
                                None => break,
                            };
                            match signal {
                                Some(signal) => match signal {
                                    Ok(_) => {
                                        let mut names = inner.registered_names.lock().await;
                                        // If the entry is gone (e.g. the name was released),
                                        // it is not recreated and the loop keeps listening.
                                        if let Some(status) = names.get_mut(&well_known_name) {
                                            let task = name_lost_fut.map(|fut| {
                                                inner.executor.spawn(fut, &lost_task_name)
                                            });
                                            *status = NameStatus::Owner(task);

                                            break;
                                        }
                                    }
                                    Err(e) => warn!("Failed to parse `NameAcquired` signal: {}", e),
                                },
                                None => {
                                    trace!("`NameAcquired` signal stream closed");
                                    break;
                                }
                            }
                        }
                    }
                    .instrument(info_span!("{}", task_name)),
                    &task_name,
                );

                NameStatus::Queued(task)
            }
            RequestNameReply::PrimaryOwner | RequestNameReply::AlreadyOwner => {
                let task = name_lost_fut.map(|fut| self.executor().spawn(fut, &lost_task_name));

                NameStatus::Owner(task)
            }
            RequestNameReply::Exists => return Err(Error::NameTaken),
        };

        names.insert(well_known_name.to_owned(), status);

        Ok(reply)
    }
769
770 pub async fn release_name<'w, W>(&self, well_known_name: W) -> Result<bool>
779 where
780 W: TryInto<WellKnownName<'w>>,
781 W::Error: Into<Error>,
782 {
783 let well_known_name: WellKnownName<'w> = well_known_name.try_into().map_err(Into::into)?;
784 let mut names = self.inner.registered_names.lock().await;
785 if names.remove(&well_known_name.to_owned()).is_none() {
787 return Ok(false);
788 };
789
790 if !self.is_bus() {
791 return Ok(true);
792 }
793
794 self.call_method(
795 Some("org.freedesktop.DBus"),
796 "/org/freedesktop/DBus",
797 Some("org.freedesktop.DBus"),
798 "ReleaseName",
799 &well_known_name,
800 )
801 .await?
802 .body()
803 .deserialize::<ReleaseNameReply>()
804 .map(|r| r == ReleaseNameReply::Released)
805 }
806
807 pub fn is_bus(&self) -> bool {
812 #[cfg(feature = "p2p")]
813 {
814 self.inner.bus_conn
815 }
816 #[cfg(not(feature = "p2p"))]
817 {
818 true
819 }
820 }
821
822 pub fn unique_name(&self) -> Option<&OwnedUniqueName> {
827 self.inner.unique_name.get()
828 }
829
830 #[cfg(feature = "bus-impl")]
840 pub fn set_unique_name<U>(&self, unique_name: U) -> Result<()>
841 where
842 U: TryInto<OwnedUniqueName>,
843 U::Error: Into<Error>,
844 {
845 let name = unique_name.try_into().map_err(Into::into)?;
846 self.set_unique_name_(name);
847
848 Ok(())
849 }
850
851 pub fn max_queued(&self) -> usize {
853 self.inner.msg_receiver.capacity()
854 }
855
856 pub fn set_max_queued(&mut self, max: usize) {
858 self.inner.msg_receiver.clone().set_capacity(max);
859 }
860
861 pub fn server_guid(&self) -> &OwnedGuid {
863 &self.inner.server_guid
864 }
865
866 pub fn executor(&self) -> &Executor<'static> {
918 &self.inner.executor
919 }
920
921 pub fn object_server(&self) -> &ObjectServer {
929 self.ensure_object_server(true)
930 }
931
932 pub(crate) fn ensure_object_server(&self, start: bool) -> &ObjectServer {
933 self.inner
934 .object_server
935 .get_or_init(move || self.setup_object_server(start, None))
936 }
937
938 fn setup_object_server(&self, start: bool, started_event: Option<Event>) -> ObjectServer {
939 if start {
940 self.start_object_server(started_event);
941 }
942
943 ObjectServer::new(self)
944 }
945
946 #[instrument(skip(self))]
947 pub(crate) fn start_object_server(&self, started_event: Option<Event>) {
948 self.inner.object_server_dispatch_task.get_or_init(|| {
949 trace!("starting ObjectServer task");
950 let weak_conn = WeakConnection::from(self);
951
952 let obj_server_task_name = "ObjectServer task";
953 self.inner.executor.spawn(
954 async move {
955 let mut stream = match weak_conn.upgrade() {
956 Some(conn) => {
957 let mut builder = MatchRule::builder().msg_type(Type::MethodCall);
958 if let Some(unique_name) = conn.unique_name() {
959 builder = builder.destination(&**unique_name).expect("unique name");
960 }
961 let rule = builder.build();
962 match conn.add_match(rule.into(), None).await {
963 Ok(stream) => stream,
964 Err(e) => {
965 debug!("Failed to create message stream: {}", e);
967
968 return;
969 }
970 }
971 }
972 None => {
973 trace!("Connection is gone, stopping associated object server task");
974
975 return;
976 }
977 };
978 if let Some(started_event) = started_event {
979 started_event.notify(1);
980 }
981
982 trace!("waiting for incoming method call messages..");
983 while let Some(msg) = stream.next().await.and_then(|m| {
984 if let Err(e) = &m {
985 debug!("Error while reading from object server stream: {:?}", e);
986 }
987 m.ok()
988 }) {
989 if let Some(conn) = weak_conn.upgrade() {
990 let hdr = msg.header();
991 if !conn.is_bus() {
994 match hdr.destination() {
995 Some(BusName::Unique(_)) | None => (),
997 Some(BusName::WellKnown(dest)) => {
998 let names = conn.inner.registered_names.lock().await;
999 if !names.is_empty() && !names.contains_key(dest) {
1003 trace!(
1004 "Got a method call for a different destination: {}",
1005 dest
1006 );
1007
1008 continue;
1009 }
1010 }
1011 }
1012 }
1013 let server = conn.object_server();
1014 if let Err(e) = server.dispatch_call(&msg, &hdr).await {
1015 debug!(
1016 "Error dispatching message. Message: {:?}, error: {:?}",
1017 msg, e
1018 );
1019 }
1020 } else {
1021 trace!("Connection is gone, stopping associated object server task");
1024 break;
1025 }
1026 }
1027 }
1028 .instrument(info_span!("{}", obj_server_task_name)),
1029 obj_server_task_name,
1030 )
1031 });
1032 }
1033
1034 pub(crate) async fn add_match(
1035 &self,
1036 rule: OwnedMatchRule,
1037 max_queued: Option<usize>,
1038 ) -> Result<Receiver<Result<Message>>> {
1039 use std::collections::hash_map::Entry;
1040
1041 if self.inner.msg_senders.lock().await.is_empty() {
1042 return Err(Error::InputOutput(Arc::new(io::Error::new(
1044 io::ErrorKind::BrokenPipe,
1045 "Socket reader task has errored out",
1046 ))));
1047 }
1048
1049 let mut subscriptions = self.inner.subscriptions.lock().await;
1050 let msg_type = rule.msg_type().unwrap_or(Type::Signal);
1051 match subscriptions.entry(rule.clone()) {
1052 Entry::Vacant(e) => {
1053 let max_queued = max_queued.unwrap_or(DEFAULT_MAX_QUEUED);
1054 let (sender, mut receiver) = broadcast(max_queued);
1055 receiver.set_await_active(false);
1056 if self.is_bus() && msg_type == Type::Signal {
1057 self.call_method(
1058 Some("org.freedesktop.DBus"),
1059 "/org/freedesktop/DBus",
1060 Some("org.freedesktop.DBus"),
1061 "AddMatch",
1062 &e.key(),
1063 )
1064 .await?;
1065 }
1066 e.insert((1, receiver.clone().deactivate()));
1067 self.inner
1068 .msg_senders
1069 .lock()
1070 .await
1071 .insert(Some(rule), sender);
1072
1073 Ok(receiver)
1074 }
1075 Entry::Occupied(mut e) => {
1076 let (num_subscriptions, receiver) = e.get_mut();
1077 *num_subscriptions += 1;
1078 if let Some(max_queued) = max_queued {
1079 if max_queued > receiver.capacity() {
1080 receiver.set_capacity(max_queued);
1081 }
1082 }
1083
1084 Ok(receiver.activate_cloned())
1085 }
1086 }
1087 }
1088
1089 pub(crate) async fn remove_match(&self, rule: OwnedMatchRule) -> Result<bool> {
1090 use std::collections::hash_map::Entry;
1091 let mut subscriptions = self.inner.subscriptions.lock().await;
1092 let msg_type = rule.msg_type().unwrap_or(Type::Signal);
1095 match subscriptions.entry(rule) {
1096 Entry::Vacant(_) => Ok(false),
1097 Entry::Occupied(mut e) => {
1098 let rule = e.key().inner().clone();
1099 e.get_mut().0 -= 1;
1100 if e.get().0 == 0 {
1101 if self.is_bus() && msg_type == Type::Signal {
1102 self.call_method(
1103 Some("org.freedesktop.DBus"),
1104 "/org/freedesktop/DBus",
1105 Some("org.freedesktop.DBus"),
1106 "RemoveMatch",
1107 &rule,
1108 )
1109 .await?;
1110 }
1111 e.remove();
1112 self.inner
1113 .msg_senders
1114 .lock()
1115 .await
1116 .remove(&Some(rule.into()));
1117 }
1118 Ok(true)
1119 }
1120 }
1121 }
1122
1123 pub(crate) fn queue_remove_match(&self, rule: OwnedMatchRule) {
1124 let conn = self.clone();
1125 let task_name = format!("Remove match `{}`", *rule);
1126 let remove_match =
1127 async move { conn.remove_match(rule).await }.instrument(trace_span!("{}", task_name));
1128 self.inner.executor.spawn(remove_match, &task_name).detach()
1129 }
1130
1131 pub fn method_timeout(&self) -> Option<Duration> {
1133 self.inner.method_timeout
1134 }
1135
1136 pub(crate) async fn new(
1137 auth: Authenticated,
1138 #[allow(unused)] bus_connection: bool,
1139 executor: Executor<'static>,
1140 method_timeout: Option<Duration>,
1141 ) -> Result<Self> {
1142 #[cfg(unix)]
1143 let cap_unix_fd = auth.cap_unix_fd;
1144
1145 macro_rules! create_msg_broadcast_channel {
1146 ($size:expr) => {{
1147 let (msg_sender, msg_receiver) = broadcast($size);
1148 let mut msg_receiver = msg_receiver.deactivate();
1149 msg_receiver.set_await_active(false);
1150
1151 (msg_sender, msg_receiver)
1152 }};
1153 }
1154 let (msg_sender, msg_receiver) = create_msg_broadcast_channel!(DEFAULT_MAX_QUEUED);
1156 let mut msg_senders = HashMap::new();
1157 msg_senders.insert(None, msg_sender);
1158
1159 let (method_return_sender, method_return_receiver) =
1161 create_msg_broadcast_channel!(DEFAULT_MAX_METHOD_RETURN_QUEUED);
1162 let rule = MatchRule::builder()
1163 .msg_type(Type::MethodReturn)
1164 .build()
1165 .into();
1166 msg_senders.insert(Some(rule), method_return_sender.clone());
1167 let rule = MatchRule::builder().msg_type(Type::Error).build().into();
1168 msg_senders.insert(Some(rule), method_return_sender);
1169 let msg_senders = Arc::new(Mutex::new(msg_senders));
1170 let subscriptions = Mutex::new(HashMap::new());
1171
1172 let connection = Self {
1173 inner: Arc::new(ConnectionInner {
1174 activity_event: Arc::new(Event::new()),
1175 socket_write: Mutex::new(auth.socket_write),
1176 server_guid: auth.server_guid,
1177 #[cfg(unix)]
1178 cap_unix_fd,
1179 #[cfg(feature = "p2p")]
1180 bus_conn: bus_connection,
1181 unique_name: OnceLock::new(),
1182 subscriptions,
1183 object_server: OnceLock::new(),
1184 object_server_dispatch_task: OnceLock::new(),
1185 executor,
1186 socket_reader_task: OnceLock::new(),
1187 msg_senders,
1188 msg_receiver,
1189 method_return_receiver,
1190 registered_names: Mutex::new(HashMap::new()),
1191 drop_event: Event::new(),
1192 method_timeout,
1193 }),
1194 };
1195
1196 if let Some(unique_name) = auth.unique_name {
1197 connection.set_unique_name_(unique_name);
1198 }
1199
1200 Ok(connection)
1201 }
1202
1203 pub async fn session() -> Result<Self> {
1205 Builder::session()?.build().await
1206 }
1207
1208 pub async fn system() -> Result<Self> {
1210 Builder::system()?.build().await
1211 }
1212
1213 pub fn monitor_activity(&self) -> EventListener {
1217 self.inner.activity_event.listen()
1218 }
1219
1220 pub async fn peer_credentials(&self) -> io::Result<ConnectionCredentials> {
1229 self.inner
1230 .socket_write
1231 .lock()
1232 .await
1233 .peer_credentials()
1234 .await
1235 }
1236
1237 pub async fn close(self) -> Result<()> {
1241 self.inner.activity_event.notify(usize::MAX);
1242 self.inner
1243 .socket_write
1244 .lock()
1245 .await
1246 .close()
1247 .await
1248 .map_err(Into::into)
1249 }
1250
1251 pub async fn graceful_shutdown(self) {
1295 let listener = self.inner.drop_event.listen();
1296 drop(self);
1297 listener.await;
1298 }
1299
1300 pub(crate) fn init_socket_reader(
1301 &self,
1302 socket_read: Box<dyn socket::ReadHalf>,
1303 already_read: Vec<u8>,
1304 #[cfg(unix)] already_received_fds: Vec<std::os::fd::OwnedFd>,
1305 ) {
1306 let inner = &self.inner;
1307 inner
1308 .socket_reader_task
1309 .set(
1310 SocketReader::new(
1311 socket_read,
1312 inner.msg_senders.clone(),
1313 already_read,
1314 #[cfg(unix)]
1315 already_received_fds,
1316 inner.activity_event.clone(),
1317 )
1318 .spawn(&inner.executor),
1319 )
1320 .expect("Attempted to set `socket_reader_task` twice");
1321 }
1322
1323 fn set_unique_name_(&self, name: OwnedUniqueName) {
1324 self.inner
1325 .unique_name
1326 .set(name)
1327 .expect("unique name already set");
1329 }
1330}
1331
#[cfg(feature = "blocking-api")]
impl From<crate::blocking::Connection> for Connection {
    /// Unwraps the asynchronous connection held by a blocking one.
    fn from(blocking: crate::blocking::Connection) -> Self {
        blocking.into_inner()
    }
}
1338
/// A non-owning handle to a `Connection`; it does not keep the connection alive.
#[derive(Debug, Clone)]
pub(crate) struct WeakConnection {
    inner: Weak<ConnectionInner>,
}
1344
1345impl WeakConnection {
1346 pub fn upgrade(&self) -> Option<Connection> {
1348 self.inner.upgrade().map(|inner| Connection { inner })
1349 }
1350}
1351
1352impl From<&Connection> for WeakConnection {
1353 fn from(conn: &Connection) -> Self {
1354 Self {
1355 inner: Arc::downgrade(&conn.inner),
1356 }
1357 }
1358}
1359
/// Local view of a well-known name requested through this connection.
#[derive(Debug)]
enum NameStatus {
    /// The name is owned. The task (never read, only held) watches for `NameLost` when the
    /// name was requested with `AllowReplacement`; otherwise it is `None`.
    Owner(#[allow(unused)] Option<Task<()>>),
    /// The request is queued. The task (never read, only held) waits for `NameAcquired`.
    Queued(#[allow(unused)] Task<()>),
}
1367
1368static SERIAL_NUM_SEMAPHORE: Semaphore = Semaphore::new(1);
1369
1370async fn acquire_serial_num_semaphore() -> Option<SemaphorePermit<'static>> {
1375 if is_flatpak() {
1376 Some(SERIAL_NUM_SEMAPHORE.acquire().await)
1377 } else {
1378 None
1379 }
1380}
1381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fdo::DBusProxy;
    use ntest::timeout;
    use std::{pin::pin, time::Duration};
    use test_log::test;

    /// Lets (paused) tokio time run out and panics if `shutdown` resolves in the meantime.
    async fn assert_shutdown_pending<F>(shutdown: &mut std::pin::Pin<&mut F>)
    where
        F: std::future::Future<Output = ()>,
    {
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
            _ = shutdown => {
                panic!("Graceful shutdown unexpectedly completed");
            }
        }
    }

    #[cfg(windows)]
    #[test]
    fn connect_autolaunch_session_bus() {
        let address =
            crate::win32::autolaunch_bus_address().expect("Unable to get session bus address");

        let outcome = crate::block_on(async { address.connect().await });
        outcome.expect("Unable to connect to session bus");
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn connect_launchd_session_bus() {
        use crate::address::{transport::Launchd, Address, Transport};

        let outcome = crate::block_on(async {
            let launchd = Launchd::new("DBUS_LAUNCHD_SESSION_BUS_SOCKET");
            let address = Address::from(Transport::Launchd(launchd));
            address.connect().await
        });
        outcome.expect("Unable to connect to session bus");
    }

    #[test]
    #[timeout(15000)]
    fn disconnect_on_drop() {
        crate::utils::block_on(test_disconnect_on_drop());
    }

    /// Dropping a connection that owns a name must make the bus release that name.
    async fn test_disconnect_on_drop() {
        #[derive(Default)]
        struct MyInterface {}

        #[crate::interface(name = "dev.peelz.FooBar.Baz")]
        impl MyInterface {
            fn do_thing(&self) {}
        }

        let name = "dev.peelz.foobar";
        let service = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at("/dev/peelz/FooBar", MyInterface::default())
            .unwrap()
            .build()
            .await
            .unwrap();

        let observer = Connection::session().await.unwrap();
        let dbus = DBusProxy::new(&observer).await.unwrap();
        let mut owner_changes = dbus
            .receive_name_owner_changed_with_args(&[(0, name), (2, "")])
            .await
            .unwrap();

        drop(service);
        owner_changes.next().await.unwrap();

        let still_owned = dbus.name_has_owner(name.try_into().unwrap()).await.unwrap();
        assert!(!still_owned);
    }

    #[tokio::test(start_paused = true)]
    #[timeout(15000)]
    async fn test_graceful_shutdown() {
        // An idle connection: shutdown must wait for the remaining clone to go away.
        let connection = Connection::session().await.unwrap();
        let clone = connection.clone();
        let mut shutdown = pin!(connection.graceful_shutdown());
        assert_shutdown_pending(&mut shutdown).await;

        drop(clone);
        shutdown.await;

        // A busy connection: shutdown must wait for an in-flight method call to finish.
        struct GracefulInterface {
            method_called: Event,
            wait_before_return: Option<EventListener>,
            announce_done: Event,
        }

        #[crate::interface(name = "dev.peelz.TestGracefulShutdown")]
        impl GracefulInterface {
            async fn do_thing(&mut self) {
                self.method_called.notify(1);
                if let Some(listener) = self.wait_before_return.take() {
                    listener.await;
                }
                self.announce_done.notify(1);
            }
        }

        let method_called = Event::new();
        let method_called_listener = method_called.listen();

        let trigger_return = Event::new();
        let wait_before_return = Some(trigger_return.listen());

        let announce_done = Event::new();
        let done_listener = announce_done.listen();

        let interface = GracefulInterface {
            method_called,
            wait_before_return,
            announce_done,
        };

        let name = "dev.peelz.TestGracefulShutdown";
        let obj = "/dev/peelz/TestGracefulShutdown";
        let service = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at(obj, interface)
            .unwrap()
            .build()
            .await
            .unwrap();

        let client_conn = Connection::session().await.unwrap();
        tokio::spawn(async move {
            client_conn
                .call_method(Some(name), obj, Some(name), "DoThing", &())
                .await
                .unwrap();
        });

        method_called_listener.await;

        let mut shutdown = pin!(service.graceful_shutdown());
        assert_shutdown_pending(&mut shutdown).await;

        trigger_return.notify(1);
        shutdown.await;

        done_listener.await;
    }
}
1550
1551#[cfg(feature = "p2p")]
1552#[cfg(test)]
1553mod p2p_tests {
1554 use event_listener::Event;
1555 use futures_util::TryStreamExt;
1556 use ntest::timeout;
1557 use test_log::test;
1558 use zvariant::{Endian, NATIVE_ENDIAN};
1559
1560 use super::{socket, Builder, Connection};
1561 use crate::{conn::AuthMechanism, Guid, Message, MessageStream, Result};
1562
1563 async fn test_p2p(
1565 server1: Connection,
1566 client1: Connection,
1567 server2: Connection,
1568 client2: Connection,
1569 ) -> Result<()> {
1570 let forward1 = {
1571 let stream = MessageStream::from(server1.clone());
1572 let sink = client2.clone();
1573
1574 stream.try_for_each(move |msg| {
1575 let sink = sink.clone();
1576 async move { sink.send(&msg).await }
1577 })
1578 };
1579 let forward2 = {
1580 let stream = MessageStream::from(client2.clone());
1581 let sink = server1.clone();
1582
1583 stream.try_for_each(move |msg| {
1584 let sink = sink.clone();
1585 async move { sink.send(&msg).await }
1586 })
1587 };
1588 let _forward_task = client1.executor().spawn(
1589 async move { futures_util::try_join!(forward1, forward2) },
1590 "forward_task",
1591 );
1592
1593 let server_ready = Event::new();
1594 let server_ready_listener = server_ready.listen();
1595 let client_done = Event::new();
1596 let client_done_listener = client_done.listen();
1597
1598 let server_future = async move {
1599 let mut stream = MessageStream::from(&server2);
1600 server_ready.notify(1);
1601 let method = loop {
1602 let m = stream.try_next().await?.unwrap();
1603 if m.to_string() == "Method call Test" {
1604 assert_eq!(m.body().deserialize::<u64>().unwrap(), 64);
1605 break m;
1606 }
1607 };
1608
1609 server2
1611 .emit_signal(None::<()>, "/", "org.zbus.p2p", "ASignalForYou", &())
1612 .await?;
1613 server2.reply(&method.header(), &("yay")).await?;
1614 client_done_listener.await;
1615
1616 Ok(())
1617 };
1618
1619 let client_future = async move {
1620 let mut stream = MessageStream::from(&client1);
1621 server_ready_listener.await;
1622 let endian = match NATIVE_ENDIAN {
1626 Endian::Little => Endian::Big,
1627 Endian::Big => Endian::Little,
1628 };
1629 let method = Message::method_call("/", "Test")?
1630 .interface("org.zbus.p2p")?
1631 .endian(endian)
1632 .build(&64u64)?;
1633 client1.send(&method).await?;
1634 let m = stream.try_next().await?.unwrap();
1636 client_done.notify(1);
1637 assert_eq!(m.to_string(), "Signal ASignalForYou");
1638 let reply = stream.try_next().await?.unwrap();
1639 assert_eq!(reply.to_string(), "Method return");
1640 assert_eq!(Endian::from(reply.primary_header().endian_sig()), endian);
1642 reply.body().deserialize::<String>()
1643 };
1644
1645 let (val, _) = futures_util::try_join!(client_future, server_future,)?;
1646 assert_eq!(val, "yay");
1647
1648 Ok(())
1649 }
1650
1651 #[test]
1652 #[timeout(15000)]
1653 fn tcp_p2p() {
1654 crate::utils::block_on(test_tcp_p2p()).unwrap();
1655 }
1656
1657 async fn test_tcp_p2p() -> Result<()> {
1658 let (server1, client1) = tcp_p2p_pipe().await?;
1659 let (server2, client2) = tcp_p2p_pipe().await?;
1660
1661 test_p2p(server1, client1, server2, client2).await
1662 }
1663
1664 async fn tcp_p2p_pipe() -> Result<(Connection, Connection)> {
1665 let guid = Guid::generate();
1666
1667 #[cfg(not(feature = "tokio"))]
1668 let (server_conn_builder, client_conn_builder) = {
1669 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
1670 let addr = listener.local_addr().unwrap();
1671 let p1 = std::net::TcpStream::connect(addr).unwrap();
1672 let p0 = listener.incoming().next().unwrap().unwrap();
1673
1674 (
1675 Builder::tcp_stream(p0)
1676 .server(guid)
1677 .unwrap()
1678 .p2p()
1679 .auth_mechanism(AuthMechanism::Anonymous),
1680 Builder::tcp_stream(p1).p2p(),
1681 )
1682 };
1683
1684 #[cfg(feature = "tokio")]
1685 let (server_conn_builder, client_conn_builder) = {
1686 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1687 let addr = listener.local_addr().unwrap();
1688 let p1 = tokio::net::TcpStream::connect(addr).await.unwrap();
1689 let p0 = listener.accept().await.unwrap().0;
1690
1691 (
1692 Builder::tcp_stream(p0)
1693 .server(guid)
1694 .unwrap()
1695 .p2p()
1696 .auth_mechanism(AuthMechanism::Anonymous),
1697 Builder::tcp_stream(p1).p2p(),
1698 )
1699 };
1700
1701 futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
1702 }
1703
1704 #[cfg(unix)]
1705 #[test]
1706 #[timeout(15000)]
1707 fn unix_p2p() {
1708 crate::utils::block_on(test_unix_p2p()).unwrap();
1709 }
1710
1711 #[cfg(unix)]
1712 async fn test_unix_p2p() -> Result<()> {
1713 let (server1, client1) = unix_p2p_pipe().await?;
1714 let (server2, client2) = unix_p2p_pipe().await?;
1715
1716 test_p2p(server1, client1, server2, client2).await
1717 }
1718
1719 #[cfg(unix)]
1720 async fn unix_p2p_pipe() -> Result<(Connection, Connection)> {
1721 #[cfg(not(feature = "tokio"))]
1722 use std::os::unix::net::UnixStream;
1723 #[cfg(feature = "tokio")]
1724 use tokio::net::UnixStream;
1725 #[cfg(all(windows, not(feature = "tokio")))]
1726 use uds_windows::UnixStream;
1727
1728 let guid = Guid::generate();
1729
1730 let (p0, p1) = UnixStream::pair().unwrap();
1731
1732 futures_util::try_join!(
1733 Builder::unix_stream(p1).p2p().build(),
1734 Builder::unix_stream(p0).server(guid).unwrap().p2p().build(),
1735 )
1736 }
1737
1738 #[cfg(any(
1739 all(feature = "vsock", not(feature = "tokio")),
1740 feature = "tokio-vsock"
1741 ))]
1742 #[test]
1743 #[timeout(15000)]
1744 fn vsock_connect() {
1745 let _ = crate::utils::block_on(test_vsock_connect()).unwrap();
1746 }
1747
1748 #[cfg(any(
1749 all(feature = "vsock", not(feature = "tokio")),
1750 feature = "tokio-vsock"
1751 ))]
1752 async fn test_vsock_connect() -> Result<(Connection, Connection)> {
1753 #[cfg(feature = "tokio-vsock")]
1754 use futures_util::StreamExt;
1755
1756 let guid = Guid::generate();
1757
1758 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1759 let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX)?;
1760 #[cfg(feature = "tokio-vsock")]
1761 let listener = tokio_vsock::VsockListener::bind(tokio_vsock::VsockAddr::new(1, u32::MAX))?;
1762
1763 let addr = listener.local_addr()?;
1764 let addr = format!("vsock:cid={},port={},guid={guid}", addr.cid(), addr.port());
1765
1766 let server = async {
1767 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1768 let server = crate::Task::spawn_blocking(move || listener.incoming().next(), "").await;
1769 #[cfg(feature = "tokio-vsock")]
1770 let server = listener.incoming().next().await;
1771 Builder::vsock_stream(server.unwrap()?)
1772 .server(guid)?
1773 .p2p()
1774 .auth_mechanism(AuthMechanism::Anonymous)
1775 .build()
1776 .await
1777 };
1778
1779 let client = crate::connection::Builder::address(addr.as_str())?
1780 .p2p()
1781 .build();
1782
1783 futures_util::try_join!(server, client)
1784 }
1785
1786 #[cfg(any(
1787 all(feature = "vsock", not(feature = "tokio")),
1788 feature = "tokio-vsock"
1789 ))]
1790 #[test]
1791 #[timeout(15000)]
1792 fn vsock_p2p() {
1793 crate::utils::block_on(test_vsock_p2p()).unwrap();
1794 }
1795
1796 #[cfg(any(
1797 all(feature = "vsock", not(feature = "tokio")),
1798 feature = "tokio-vsock"
1799 ))]
1800 async fn test_vsock_p2p() -> Result<()> {
1801 let (server1, client1) = vsock_p2p_pipe().await?;
1802 let (server2, client2) = vsock_p2p_pipe().await?;
1803
1804 test_p2p(server1, client1, server2, client2).await
1805 }
1806
1807 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1808 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1809 let guid = Guid::generate();
1810
1811 let listener =
1812 vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX).unwrap();
1813 let addr = listener.local_addr().unwrap();
1814 let client = vsock::VsockStream::connect(&addr).unwrap();
1815 let server = listener.incoming().next().unwrap().unwrap();
1816
1817 futures_util::try_join!(
1818 Builder::vsock_stream(server)
1819 .server(guid)
1820 .unwrap()
1821 .p2p()
1822 .auth_mechanism(AuthMechanism::Anonymous)
1823 .build(),
1824 Builder::vsock_stream(client).p2p().build(),
1825 )
1826 }
1827
1828 #[cfg(feature = "tokio-vsock")]
1829 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1830 use futures_util::StreamExt;
1831 use tokio_vsock::VsockAddr;
1832
1833 let guid = Guid::generate();
1834
1835 let listener = tokio_vsock::VsockListener::bind(VsockAddr::new(1, u32::MAX)).unwrap();
1836 let addr = listener.local_addr().unwrap();
1837 let client = tokio_vsock::VsockStream::connect(addr).await.unwrap();
1838 let server = listener.incoming().next().await.unwrap().unwrap();
1839
1840 futures_util::try_join!(
1841 Builder::vsock_stream(server)
1842 .server(guid)
1843 .unwrap()
1844 .p2p()
1845 .auth_mechanism(AuthMechanism::Anonymous)
1846 .build(),
1847 Builder::vsock_stream(client).p2p().build(),
1848 )
1849 }
1850
1851 #[test]
1852 #[timeout(15000)]
1853 fn channel_pair() {
1854 crate::utils::block_on(test_channel_pair()).unwrap();
1855 }
1856
1857 async fn test_channel_pair() -> Result<()> {
1858 let (server1, client1) = create_channel_pair().await;
1859 let (server2, client2) = create_channel_pair().await;
1860
1861 test_p2p(server1, client1, server2, client2).await
1862 }
1863
1864 async fn create_channel_pair() -> (Connection, Connection) {
1865 let (a, b) = socket::Channel::pair();
1866
1867 let guid = crate::Guid::generate();
1868 let conn1 = Builder::authenticated_socket(a, guid.clone())
1869 .unwrap()
1870 .p2p()
1871 .build()
1872 .await
1873 .unwrap();
1874 let conn2 = Builder::authenticated_socket(b, guid)
1875 .unwrap()
1876 .p2p()
1877 .build()
1878 .await
1879 .unwrap();
1880
1881 (conn1, conn2)
1882 }
1883}