1use async_broadcast::{broadcast, InactiveReceiver, Receiver, Sender as Broadcaster};
3use enumflags2::BitFlags;
4use event_listener::{Event, EventListener};
5use ordered_stream::{OrderedFuture, OrderedStream, PollResult};
6use std::{
7 collections::HashMap,
8 io::{self, ErrorKind},
9 num::NonZeroU32,
10 pin::Pin,
11 sync::{Arc, OnceLock, Weak},
12 task::{Context, Poll},
13};
14use tracing::{debug, info_span, instrument, trace, trace_span, warn, Instrument};
15use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
16use zvariant::ObjectPath;
17
18use futures_core::Future;
19use futures_lite::StreamExt;
20
21use crate::{
22 async_lock::{Mutex, Semaphore, SemaphorePermit},
23 fdo::{ConnectionCredentials, ReleaseNameReply, RequestNameFlags, RequestNameReply},
24 is_flatpak,
25 message::{Flags, Message, Type},
26 DBusError, Error, Executor, MatchRule, MessageStream, ObjectServer, OwnedGuid, OwnedMatchRule,
27 Result, Task,
28};
29
30mod builder;
31pub use builder::Builder;
32
33pub mod socket;
34pub use socket::Socket;
35
36mod socket_reader;
37use socket_reader::SocketReader;
38
39pub(crate) mod handshake;
40pub use handshake::AuthMechanism;
41use handshake::Authenticated;
42
/// Default capacity of a per-match-rule message channel created by `add_match`, and of the
/// catch-all message channel created in `Connection::new`.
const DEFAULT_MAX_QUEUED: usize = 64;
/// Default capacity of the channel that carries method returns and errors (the replies awaited
/// by `PendingMethodCall`).
const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8;
45
/// State shared by every clone of a [`Connection`] (held behind an `Arc`).
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// Server GUID obtained from the authentication handshake.
    server_guid: OwnedGuid,
    /// Whether Unix FD passing was negotiated; `Connection::send` rejects messages carrying
    /// FDs when this is `false`.
    #[cfg(unix)]
    cap_unix_fd: bool,
    /// `true` for a bus connection, `false` for a peer-to-peer one.
    #[cfg(feature = "p2p")]
    bus_conn: bool,
    /// Unique name of this connection; can only be set once.
    unique_name: OnceLock<OwnedUniqueName>,
    /// Well-known names requested through this connection, with their current status.
    registered_names: Mutex<HashMap<WellKnownName<'static>, NameStatus>>,

    /// Notified by `send` and `close`; a clone is also handed to the socket reader.
    activity_event: Arc<Event>,
    /// Write half of the socket, locked for the duration of every write/close.
    socket_write: Mutex<Box<dyn socket::WriteHalf>>,

    /// Executor on which the connection's internal tasks are spawned.
    executor: Executor<'static>,

    // Never read: presumably held only so the socket reader task lives exactly as long as the
    // connection does (TODO confirm `Task` drop semantics).
    #[allow(unused)]
    socket_reader_task: OnceLock<Task<()>>,

    /// Inactive receiver of the catch-all (`None` rule) message channel.
    pub(crate) msg_receiver: InactiveReceiver<Result<Message>>,
    /// Inactive receiver of the channel carrying method returns and errors.
    pub(crate) method_return_receiver: InactiveReceiver<Result<Message>>,
    /// Sender for each subscribed match rule (`None` = catch-all); shared with the socket
    /// reader, which routes incoming messages through these.
    msg_senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,

    /// Match rules currently subscribed to, with subscriber counts.
    subscriptions: Mutex<Subscriptions>,

    /// Lazily created object server.
    object_server: OnceLock<ObjectServer>,
    /// Task dispatching incoming method calls to the object server, once started.
    object_server_dispatch_task: OnceLock<Task<()>>,

    /// Notified from `Drop`; awaited by `Connection::graceful_shutdown`.
    drop_event: Event,
}
78
impl Drop for ConnectionInner {
    fn drop(&mut self) {
        // Wake every listener, notably a pending `Connection::graceful_shutdown`.
        self.drop_event.notify(usize::MAX);
    }
}
87
/// Match rule -> (number of active subscriptions, inactive receiver from which new active
/// receivers are cloned for further subscribers).
type Subscriptions = HashMap<OwnedMatchRule, (u64, InactiveReceiver<Result<Message>>)>;

/// Sending half of a message channel, one per subscribed match rule.
pub(crate) type MsgBroadcaster = Broadcaster<Result<Message>>;
91
/// A D-Bus connection.
///
/// Cloning is cheap: all clones share one [`ConnectionInner`] through an `Arc`.
#[derive(Clone, Debug)]
#[must_use = "Dropping a `Connection` will close the underlying socket."]
pub struct Connection {
    pub(crate) inner: Arc<ConnectionInner>,
}
211
/// A method call that was sent and whose reply has not been received yet.
///
/// Awaiting it yields the method return, or the error reply converted into [`Error`].
#[derive(Debug)]
pub(crate) struct PendingMethodCall {
    /// Stream of method returns/errors; set to `None` once the reply has been delivered.
    stream: Option<MessageStream>,
    /// Serial of the outgoing call, compared against the `reply_serial` of incoming messages.
    serial: NonZeroU32,
}
222
223impl Future for PendingMethodCall {
224 type Output = Result<Message>;
225
226 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227 self.poll_before(cx, None).map(|ret| {
228 ret.map(|(_, r)| r).unwrap_or_else(|| {
229 Err(crate::Error::InputOutput(
230 io::Error::new(ErrorKind::BrokenPipe, "socket closed").into(),
231 ))
232 })
233 })
234 }
235}
236
237impl OrderedFuture for PendingMethodCall {
238 type Output = Result<Message>;
239 type Ordering = zbus::message::Sequence;
240
241 fn poll_before(
242 self: Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 before: Option<&Self::Ordering>,
245 ) -> Poll<Option<(Self::Ordering, Self::Output)>> {
246 let this = self.get_mut();
247 if let Some(stream) = &mut this.stream {
248 loop {
249 match Pin::new(&mut *stream).poll_next_before(cx, before) {
250 Poll::Ready(PollResult::Item {
251 data: Ok(msg),
252 ordering,
253 }) => {
254 if msg.header().reply_serial() != Some(this.serial) {
255 continue;
256 }
257 let res = match msg.message_type() {
258 Type::Error => Err(msg.into()),
259 Type::MethodReturn => Ok(msg),
260 _ => continue,
261 };
262 this.stream = None;
263 return Poll::Ready(Some((ordering, res)));
264 }
265 Poll::Ready(PollResult::Item {
266 data: Err(e),
267 ordering,
268 }) => {
269 return Poll::Ready(Some((ordering, Err(e))));
270 }
271
272 Poll::Ready(PollResult::NoneBefore) => {
273 return Poll::Ready(None);
274 }
275 Poll::Ready(PollResult::Terminated) => {
276 return Poll::Ready(None);
277 }
278 Poll::Pending => return Poll::Pending,
279 }
280 }
281 }
282 Poll::Ready(None)
283 }
284}
285
impl Connection {
    /// Sends `msg` over this connection.
    ///
    /// On Unix, fails with [`Error::Unsupported`] if `msg` carries file descriptors but FD
    /// passing was not negotiated. Activity listeners (see [`Connection::monitor_activity`])
    /// are notified before the write half is locked and the message written.
    pub async fn send(&self, msg: &Message) -> Result<()> {
        #[cfg(unix)]
        if !msg.data().fds().is_empty() && !self.inner.cap_unix_fd {
            return Err(Error::Unsupported);
        }

        self.inner.activity_event.notify(usize::MAX);
        let mut write = self.inner.socket_write.lock().await;

        write.send_message(msg).await
    }

    /// Calls a method and waits for its reply.
    ///
    /// Returns the method-return message; an error reply is converted into an [`Error`] and
    /// returned as `Err`. Conversion failures of `destination`, `path`, `interface` or
    /// `method_name`, and send failures, are propagated.
    pub async fn call_method<'d, 'p, 'i, 'm, D, P, I, M, B>(
        &self,
        destination: Option<D>,
        path: P,
        interface: Option<I>,
        method_name: M,
        body: &B,
    ) -> Result<Message>
    where
        D: TryInto<BusName<'d>>,
        P: TryInto<ObjectPath<'p>>,
        I: TryInto<InterfaceName<'i>>,
        M: TryInto<MemberName<'m>>,
        D::Error: Into<Error>,
        P::Error: Into<Error>,
        I::Error: Into<Error>,
        M::Error: Into<Error>,
        B: serde::ser::Serialize + zvariant::DynamicType,
    {
        self.call_method_raw(
            destination,
            path,
            interface,
            method_name,
            BitFlags::empty(),
            body,
        )
        .await?
        // No flags means `NoReplyExpected` is unset, so a pending call is always returned.
        .expect("no reply")
        .await
    }

    /// Builds and sends a method call with the given `flags`.
    ///
    /// Returns `Ok(None)` when `flags` contains `NoReplyExpected`; otherwise returns a
    /// [`PendingMethodCall`] that resolves to the reply. The sender field is set to this
    /// connection's unique name when one is known.
    pub(crate) async fn call_method_raw<'d, 'p, 'i, 'm, D, P, I, M, B>(
        &self,
        destination: Option<D>,
        path: P,
        interface: Option<I>,
        method_name: M,
        flags: BitFlags<Flags>,
        body: &B,
    ) -> Result<Option<PendingMethodCall>>
    where
        D: TryInto<BusName<'d>>,
        P: TryInto<ObjectPath<'p>>,
        I: TryInto<InterfaceName<'i>>,
        M: TryInto<MemberName<'m>>,
        D::Error: Into<Error>,
        P::Error: Into<Error>,
        I::Error: Into<Error>,
        M::Error: Into<Error>,
        B: serde::ser::Serialize + zvariant::DynamicType,
    {
        // Held until this function returns, i.e. across both building and sending.
        let _permit = acquire_serial_num_semaphore().await;

        let mut builder = Message::method_call(path, method_name)?;
        if let Some(sender) = self.unique_name() {
            builder = builder.sender(sender)?
        }
        if let Some(destination) = destination {
            builder = builder.destination(destination)?
        }
        if let Some(interface) = interface {
            builder = builder.interface(interface)?
        }
        for flag in flags {
            builder = builder.with_flags(flag)?;
        }
        let msg = builder.build(body)?;

        // Subscribe to replies *before* sending so the reply cannot arrive unobserved.
        let msg_receiver = self.inner.method_return_receiver.activate_cloned();
        let stream = Some(MessageStream::for_subscription_channel(
            msg_receiver,
            // No match rule is given for this internal-only stream.
            // NOTE(review): presumably fine because the stream never leaves this module.
            None,
            self,
        ));
        let serial = msg.primary_header().serial_num();
        self.send(&msg).await?;
        if flags.contains(Flags::NoReplyExpected) {
            Ok(None)
        } else {
            Ok(Some(PendingMethodCall { stream, serial }))
        }
    }

    /// Emits a signal, optionally addressed to `destination`.
    ///
    /// The sender field is set to this connection's unique name when one is known.
    pub async fn emit_signal<'d, 'p, 'i, 'm, D, P, I, M, B>(
        &self,
        destination: Option<D>,
        path: P,
        interface: I,
        signal_name: M,
        body: &B,
    ) -> Result<()>
    where
        D: TryInto<BusName<'d>>,
        P: TryInto<ObjectPath<'p>>,
        I: TryInto<InterfaceName<'i>>,
        M: TryInto<MemberName<'m>>,
        D::Error: Into<Error>,
        P::Error: Into<Error>,
        I::Error: Into<Error>,
        M::Error: Into<Error>,
        B: serde::ser::Serialize + zvariant::DynamicType,
    {
        let _permit = acquire_serial_num_semaphore().await;

        let mut b = Message::signal(path, interface, signal_name)?;
        if let Some(sender) = self.unique_name() {
            b = b.sender(sender)?;
        }
        if let Some(destination) = destination {
            b = b.destination(destination)?;
        }
        let m = b.build(body)?;

        self.send(&m).await
    }

    /// Sends a method return replying to the call described by `call`.
    pub async fn reply<B>(&self, call: &zbus::message::Header<'_>, body: &B) -> Result<()>
    where
        B: serde::ser::Serialize + zvariant::DynamicType,
    {
        let _permit = acquire_serial_num_semaphore().await;

        let mut b = Message::method_return(call)?;
        if let Some(sender) = self.unique_name() {
            b = b.sender(sender)?;
        }
        let m = b.build(body)?;
        self.send(&m).await
    }

    /// Sends an error reply named `error_name` to the call described by `call`.
    pub async fn reply_error<'e, E, B>(
        &self,
        call: &zbus::message::Header<'_>,
        error_name: E,
        body: &B,
    ) -> Result<()>
    where
        B: serde::ser::Serialize + zvariant::DynamicType,
        E: TryInto<ErrorName<'e>>,
        E::Error: Into<Error>,
    {
        let _permit = acquire_serial_num_semaphore().await;

        let mut b = Message::error(call, error_name)?;
        if let Some(sender) = self.unique_name() {
            b = b.sender(sender)?;
        }
        let m = b.build(body)?;
        self.send(&m).await
    }

    /// Sends the error reply produced by `err.create_reply(call)`.
    ///
    /// Unlike [`Connection::reply_error`], no sender field is set here.
    pub async fn reply_dbus_error(
        &self,
        call: &zbus::message::Header<'_>,
        err: impl DBusError,
    ) -> Result<()> {
        let _permit = acquire_serial_num_semaphore().await;

        let m = err.create_reply(call)?;
        self.send(&m).await
    }

    /// Requests `well_known_name` with default flags.
    ///
    /// Equivalent to [`Connection::request_name_with_flags`] with the reply discarded:
    /// `Ok(())` is returned whether the name was acquired or the request was queued.
    pub async fn request_name<'w, W>(&self, well_known_name: W) -> Result<()>
    where
        W: TryInto<WellKnownName<'w>>,
        W::Error: Into<Error>,
    {
        self.request_name_with_flags(well_known_name, BitFlags::default())
            .await
            .map(|_| ())
    }

    /// Requests `well_known_name` with the given `flags`.
    ///
    /// - If the name is already registered through this connection, returns `AlreadyOwner`
    ///   or `InQueue` without contacting the bus.
    /// - On a peer-to-peer connection, records the name as owned and returns `PrimaryOwner`.
    /// - Otherwise, subscribes to `NameAcquired`/`NameLost` for the name and calls
    ///   `org.freedesktop.DBus.RequestName`.
    ///   - When queued, a task is spawned that marks the name as owned once `NameAcquired`
    ///     arrives.
    ///   - With `AllowReplacement`, a task is spawned (once the name is owned) that removes it
    ///     from the registry when `NameLost` arrives.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NameTaken`] if the bus replies `Exists`. Conversion, subscription and
    /// method-call errors are propagated.
    ///
    /// The registry lock is held for the whole operation, including the bus calls.
    pub async fn request_name_with_flags<'w, W>(
        &self,
        well_known_name: W,
        flags: BitFlags<RequestNameFlags>,
    ) -> Result<RequestNameReply>
    where
        W: TryInto<WellKnownName<'w>>,
        W::Error: Into<Error>,
    {
        let well_known_name = well_known_name.try_into().map_err(Into::into)?;
        let mut names = self.inner.registered_names.lock().await;

        // Already requested through this connection: report the current status.
        match names.get(&well_known_name) {
            Some(NameStatus::Owner(_)) => return Ok(RequestNameReply::AlreadyOwner),
            Some(NameStatus::Queued(_)) => return Ok(RequestNameReply::InQueue),
            None => (),
        }

        // No bus to arbitrate ownership on a peer-to-peer connection.
        if !self.is_bus() {
            names.insert(well_known_name.to_owned(), NameStatus::Owner(None));

            return Ok(RequestNameReply::PrimaryOwner);
        }

        // Subscribe before calling `RequestName` so neither signal can be missed.
        let acquired_match_rule = MatchRule::fdo_signal_builder("NameAcquired")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut acquired_stream = self.add_match(acquired_match_rule.into(), None).await?;
        let lost_match_rule = MatchRule::fdo_signal_builder("NameLost")
            .arg(0, well_known_name.as_ref())
            .unwrap()
            .build();
        let mut lost_stream = self.add_match(lost_match_rule.into(), None).await?;
        let reply = self
            .call_method(
                Some("org.freedesktop.DBus"),
                "/org/freedesktop/DBus",
                Some("org.freedesktop.DBus"),
                "RequestName",
                &(well_known_name.clone(), flags),
            )
            .await?
            .body()
            .deserialize::<RequestNameReply>()?;
        let lost_task_name = format!("monitor name {well_known_name} lost");
        // Built now but only spawned once the name is actually owned (immediately below, or
        // from the `NameAcquired` task).
        let name_lost_fut = if flags.contains(RequestNameFlags::AllowReplacement) {
            let weak_conn = WeakConnection::from(self);
            let well_known_name = well_known_name.to_owned();
            Some(
                async move {
                    loop {
                        let signal = lost_stream.next().await;
                        // Stop once the connection itself is gone.
                        let inner = match weak_conn.upgrade() {
                            Some(conn) => conn.inner.clone(),
                            None => break,
                        };

                        match signal {
                            Some(signal) => match signal {
                                Ok(_) => {
                                    tracing::info!(
                                        "Connection `{}` lost name `{}`",
                                        // Only reached on bus connections, where the unique
                                        // name is presumably always set. TODO confirm.
                                        inner.unique_name.get().unwrap(),
                                        well_known_name
                                    );
                                    inner.registered_names.lock().await.remove(&well_known_name);

                                    break;
                                }
                                Err(e) => warn!("Failed to parse `NameLost` signal: {}", e),
                            },
                            None => {
                                trace!("`NameLost` signal stream closed");
                                break;
                            }
                        }
                    }
                }
                .instrument(info_span!("{}", lost_task_name)),
            )
        } else {
            None
        };
        let status = match reply {
            RequestNameReply::InQueue => {
                let weak_conn = WeakConnection::from(self);
                let well_known_name = well_known_name.to_owned();
                let task_name = format!("monitor name {well_known_name} acquired");
                let task = self.executor().spawn(
                    async move {
                        loop {
                            let signal = acquired_stream.next().await;
                            let inner = match weak_conn.upgrade() {
                                Some(conn) => conn.inner.clone(),
                                None => break,
                            };
                            match signal {
                                Some(signal) => match signal {
                                    Ok(_) => {
                                        let mut names = inner.registered_names.lock().await;
                                        if let Some(status) = names.get_mut(&well_known_name) {
                                            // Now owned: start watching for `NameLost` if
                                            // replacement was allowed.
                                            let task = name_lost_fut.map(|fut| {
                                                inner.executor.spawn(fut, &lost_task_name)
                                            });
                                            *status = NameStatus::Owner(task);

                                            break;
                                        }
                                        // Name no longer registered here (e.g. released in
                                        // the meantime): keep waiting.
                                    }
                                    Err(e) => warn!("Failed to parse `NameAcquired` signal: {}", e),
                                },
                                None => {
                                    trace!("`NameAcquired` signal stream closed");
                                    break;
                                }
                            }
                        }
                    }
                    .instrument(info_span!("{}", task_name)),
                    &task_name,
                );

                NameStatus::Queued(task)
            }
            RequestNameReply::PrimaryOwner | RequestNameReply::AlreadyOwner => {
                let task = name_lost_fut.map(|fut| self.executor().spawn(fut, &lost_task_name));

                NameStatus::Owner(task)
            }
            RequestNameReply::Exists => return Err(Error::NameTaken),
        };

        names.insert(well_known_name.to_owned(), status);

        Ok(reply)
    }

    /// Releases `well_known_name` if it was requested through this connection.
    ///
    /// Returns `Ok(false)` if the name is not registered here. On a peer-to-peer connection,
    /// returns `Ok(true)` after forgetting it. Otherwise calls
    /// `org.freedesktop.DBus.ReleaseName` and returns whether the reply is `Released`.
    ///
    /// The local entry (and the monitoring task(s) it holds) is removed *before* the bus call,
    /// so it is gone even if that call fails.
    pub async fn release_name<'w, W>(&self, well_known_name: W) -> Result<bool>
    where
        W: TryInto<WellKnownName<'w>>,
        W::Error: Into<Error>,
    {
        let well_known_name: WellKnownName<'w> = well_known_name.try_into().map_err(Into::into)?;
        let mut names = self.inner.registered_names.lock().await;
        if names.remove(&well_known_name.to_owned()).is_none() {
            return Ok(false);
        };

        if !self.is_bus() {
            return Ok(true);
        }

        self.call_method(
            Some("org.freedesktop.DBus"),
            "/org/freedesktop/DBus",
            Some("org.freedesktop.DBus"),
            "ReleaseName",
            &well_known_name,
        )
        .await?
        .body()
        .deserialize::<ReleaseNameReply>()
        .map(|r| r == ReleaseNameReply::Released)
    }

    /// Whether this is a bus connection.
    ///
    /// Always `true` unless the `p2p` feature is enabled and the connection is peer-to-peer.
    pub fn is_bus(&self) -> bool {
        #[cfg(feature = "p2p")]
        {
            self.inner.bus_conn
        }
        #[cfg(not(feature = "p2p"))]
        {
            true
        }
    }

    /// The unique name of this connection, or `None` if it has not been set.
    pub fn unique_name(&self) -> Option<&OwnedUniqueName> {
        self.inner.unique_name.get()
    }

    /// Sets the unique name of this connection.
    ///
    /// # Errors
    ///
    /// Returns the conversion error if `unique_name` is not a valid unique name.
    ///
    /// # Panics
    ///
    /// Panics if a unique name has already been set.
    #[cfg(feature = "bus-impl")]
    pub fn set_unique_name<U>(&self, unique_name: U) -> Result<()>
    where
        U: TryInto<OwnedUniqueName>,
        U::Error: Into<Error>,
    {
        let name = unique_name.try_into().map_err(Into::into)?;
        self.set_unique_name_(name);

        Ok(())
    }

    /// Capacity of the catch-all incoming-message channel.
    pub fn max_queued(&self) -> usize {
        self.inner.msg_receiver.capacity()
    }

    /// Sets the capacity of the catch-all incoming-message channel.
    pub fn set_max_queued(&mut self, max: usize) {
        // Applied through a clone of the inactive receiver; this relies on async_broadcast
        // receivers sharing the channel's capacity setting.
        self.inner.msg_receiver.clone().set_capacity(max);
    }

    /// GUID of the server side of this connection.
    pub fn server_guid(&self) -> &OwnedGuid {
        &self.inner.server_guid
    }

    /// Executor on which this connection spawns its internal tasks.
    pub fn executor(&self) -> &Executor<'static> {
        &self.inner.executor
    }

    /// The object server of this connection.
    ///
    /// It is created, and its dispatch task started, on first use.
    pub fn object_server(&self) -> &ObjectServer {
        self.ensure_object_server(true)
    }

    /// Returns the object server, creating it on first call.
    ///
    /// `start` only takes effect on the call that creates the server; it decides whether
    /// the dispatch task is started at that point.
    pub(crate) fn ensure_object_server(&self, start: bool) -> &ObjectServer {
        self.inner
            .object_server
            .get_or_init(move || self.setup_object_server(start, None))
    }

    /// Creates a new object server, first starting the dispatch task if `start` is set.
    fn setup_object_server(&self, start: bool, started_event: Option<Event>) -> ObjectServer {
        if start {
            self.start_object_server(started_event);
        }

        ObjectServer::new(self)
    }

    /// Starts the task that dispatches incoming method calls to the object server.
    ///
    /// Does nothing if the task was already started. The task:
    /// - subscribes to method calls (addressed to our unique name, if known);
    /// - notifies `started_event` once subscribed;
    /// - dispatches each incoming call.
    ///
    /// It only holds a weak reference to the connection between messages. It stops when the
    /// connection is gone, the stream ends, or the stream yields its first error.
    #[instrument(skip(self))]
    pub(crate) fn start_object_server(&self, started_event: Option<Event>) {
        self.inner.object_server_dispatch_task.get_or_init(|| {
            trace!("starting ObjectServer task");
            let weak_conn = WeakConnection::from(self);

            let obj_server_task_name = "ObjectServer task";
            self.inner.executor.spawn(
                async move {
                    let mut stream = match weak_conn.upgrade() {
                        Some(conn) => {
                            let mut builder = MatchRule::builder().msg_type(Type::MethodCall);
                            if let Some(unique_name) = conn.unique_name() {
                                builder = builder.destination(&**unique_name).expect("unique name");
                            }
                            let rule = builder.build();
                            match conn.add_match(rule.into(), None).await {
                                Ok(stream) => stream,
                                Err(e) => {
                                    // Without a stream there's nothing to dispatch; give up.
                                    debug!("Failed to create message stream: {}", e);

                                    return;
                                }
                            }
                        }
                        None => {
                            trace!("Connection is gone, stopping associated object server task");

                            return;
                        }
                    };
                    if let Some(started_event) = started_event {
                        started_event.notify(1);
                    }

                    trace!("waiting for incoming method call messages..");
                    // A stream error is logged and then ends the loop (`ok()` yields `None`).
                    while let Some(msg) = stream.next().await.and_then(|m| {
                        if let Err(e) = &m {
                            debug!("Error while reading from object server stream: {:?}", e);
                        }
                        m.ok()
                    }) {
                        if let Some(conn) = weak_conn.upgrade() {
                            let hdr = msg.header();
                            if !conn.is_bus() {
                                // Peer-to-peer: skip calls addressed to a well-known name this
                                // connection hasn't registered. If no names are registered at
                                // all, every call is accepted.
                                match hdr.destination() {
                                    Some(BusName::Unique(_)) | None => (),
                                    Some(BusName::WellKnown(dest)) => {
                                        let names = conn.inner.registered_names.lock().await;
                                        if !names.is_empty() && !names.contains_key(dest) {
                                            trace!(
                                                "Got a method call for a different destination: {}",
                                                dest
                                            );

                                            continue;
                                        }
                                    }
                                }
                            }
                            let server = conn.object_server();
                            if let Err(e) = server.dispatch_call(&msg, &hdr).await {
                                debug!(
                                    "Error dispatching message. Message: {:?}, error: {:?}",
                                    msg, e
                                );
                            }
                        } else {
                            trace!("Connection is gone, stopping associated object server task");
                            break;
                        }
                    }
                }
                .instrument(info_span!("{}", obj_server_task_name)),
                obj_server_task_name,
            )
        });
    }

    /// Subscribes to messages matching `rule` and returns an active receiver for them.
    ///
    /// First subscription to a rule:
    /// - creates a channel with capacity `max_queued` (default [`DEFAULT_MAX_QUEUED`]);
    /// - on a bus, sends `AddMatch` for signal rules (a rule without a message type counts
    ///   as a signal rule);
    /// - registers the channel's sender for routing.
    ///
    /// Later subscriptions increment the count, grow the capacity if `max_queued` is larger,
    /// and return a new receiver on the same channel.
    ///
    /// # Errors
    ///
    /// - `BrokenPipe` I/O error if no senders remain; presumably the socket reader cleared
    ///   them after failing (TODO confirm).
    /// - Errors from the `AddMatch` call.
    pub(crate) async fn add_match(
        &self,
        rule: OwnedMatchRule,
        max_queued: Option<usize>,
    ) -> Result<Receiver<Result<Message>>> {
        use std::collections::hash_map::Entry;

        if self.inner.msg_senders.lock().await.is_empty() {
            return Err(Error::InputOutput(Arc::new(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Socket reader task has errored out",
            ))));
        }

        let mut subscriptions = self.inner.subscriptions.lock().await;
        let msg_type = rule.msg_type().unwrap_or(Type::Signal);
        match subscriptions.entry(rule.clone()) {
            Entry::Vacant(e) => {
                let max_queued = max_queued.unwrap_or(DEFAULT_MAX_QUEUED);
                let (sender, mut receiver) = broadcast(max_queued);
                receiver.set_await_active(false);
                if self.is_bus() && msg_type == Type::Signal {
                    self.call_method(
                        Some("org.freedesktop.DBus"),
                        "/org/freedesktop/DBus",
                        Some("org.freedesktop.DBus"),
                        "AddMatch",
                        &e.key(),
                    )
                    .await?;
                }
                // Keep an inactive receiver so later subscribers can be handed clones.
                e.insert((1, receiver.clone().deactivate()));
                self.inner
                    .msg_senders
                    .lock()
                    .await
                    .insert(Some(rule), sender);

                Ok(receiver)
            }
            Entry::Occupied(mut e) => {
                let (num_subscriptions, receiver) = e.get_mut();
                *num_subscriptions += 1;
                if let Some(max_queued) = max_queued {
                    if max_queued > receiver.capacity() {
                        receiver.set_capacity(max_queued);
                    }
                }

                Ok(receiver.activate_cloned())
            }
        }
    }

    /// Drops one subscription to `rule`.
    ///
    /// Returns `Ok(false)` if `rule` isn't subscribed to. When the last subscription goes:
    /// - on a bus, sends `RemoveMatch` for signal rules;
    /// - removes the subscription and its sender.
    ///
    /// If `RemoveMatch` fails, the error is returned and the entry is left in place with a
    /// count of zero (the decrement happens first).
    pub(crate) async fn remove_match(&self, rule: OwnedMatchRule) -> Result<bool> {
        use std::collections::hash_map::Entry;
        let mut subscriptions = self.inner.subscriptions.lock().await;
        let msg_type = rule.msg_type().unwrap_or(Type::Signal);
        match subscriptions.entry(rule) {
            Entry::Vacant(_) => Ok(false),
            Entry::Occupied(mut e) => {
                let rule = e.key().inner().clone();
                e.get_mut().0 -= 1;
                if e.get().0 == 0 {
                    if self.is_bus() && msg_type == Type::Signal {
                        self.call_method(
                            Some("org.freedesktop.DBus"),
                            "/org/freedesktop/DBus",
                            Some("org.freedesktop.DBus"),
                            "RemoveMatch",
                            &rule,
                        )
                        .await?;
                    }
                    e.remove();
                    self.inner
                        .msg_senders
                        .lock()
                        .await
                        .remove(&Some(rule.into()));
                }
                Ok(true)
            }
        }
    }

    /// Schedules [`Connection::remove_match`] for `rule` on a detached task.
    ///
    /// Usable from non-async contexts. The result is ignored, and the task holds a clone of
    /// the connection until it has run.
    pub(crate) fn queue_remove_match(&self, rule: OwnedMatchRule) {
        let conn = self.clone();
        let task_name = format!("Remove match `{}`", *rule);
        let remove_match =
            async move { conn.remove_match(rule).await }.instrument(trace_span!("{}", task_name));
        self.inner.executor.spawn(remove_match, &task_name).detach()
    }

    /// Creates a connection from an authenticated socket.
    ///
    /// Sets up two channels:
    /// - the catch-all channel;
    /// - a channel for method returns and errors, whose single sender is registered under
    ///   both the `MethodReturn` and the `Error` rule.
    ///
    /// Records the unique name if the handshake provided one. The socket reader is *not*
    /// started here; see [`Connection::init_socket_reader`].
    pub(crate) async fn new(
        auth: Authenticated,
        // Only read when the `p2p` feature is enabled.
        #[allow(unused)] bus_connection: bool,
        executor: Executor<'static>,
    ) -> Result<Self> {
        #[cfg(unix)]
        let cap_unix_fd = auth.cap_unix_fd;

        // Creates a channel whose only receiver is inactive and with `await_active` disabled,
        // so (per async_broadcast semantics) sends don't wait for an active receiver.
        macro_rules! create_msg_broadcast_channel {
            ($size:expr) => {{
                let (msg_sender, msg_receiver) = broadcast($size);
                let mut msg_receiver = msg_receiver.deactivate();
                msg_receiver.set_await_active(false);

                (msg_sender, msg_receiver)
            }};
        }
        let (msg_sender, msg_receiver) = create_msg_broadcast_channel!(DEFAULT_MAX_QUEUED);
        let mut msg_senders = HashMap::new();
        msg_senders.insert(None, msg_sender);

        let (method_return_sender, method_return_receiver) =
            create_msg_broadcast_channel!(DEFAULT_MAX_METHOD_RETURN_QUEUED);
        let rule = MatchRule::builder()
            .msg_type(Type::MethodReturn)
            .build()
            .into();
        msg_senders.insert(Some(rule), method_return_sender.clone());
        let rule = MatchRule::builder().msg_type(Type::Error).build().into();
        msg_senders.insert(Some(rule), method_return_sender);
        let msg_senders = Arc::new(Mutex::new(msg_senders));
        let subscriptions = Mutex::new(HashMap::new());

        let connection = Self {
            inner: Arc::new(ConnectionInner {
                activity_event: Arc::new(Event::new()),
                socket_write: Mutex::new(auth.socket_write),
                server_guid: auth.server_guid,
                #[cfg(unix)]
                cap_unix_fd,
                #[cfg(feature = "p2p")]
                bus_conn: bus_connection,
                unique_name: OnceLock::new(),
                subscriptions,
                object_server: OnceLock::new(),
                object_server_dispatch_task: OnceLock::new(),
                executor,
                socket_reader_task: OnceLock::new(),
                msg_senders,
                msg_receiver,
                method_return_receiver,
                registered_names: Mutex::new(HashMap::new()),
                drop_event: Event::new(),
            }),
        };

        if let Some(unique_name) = auth.unique_name {
            connection.set_unique_name_(unique_name);
        }

        Ok(connection)
    }

    /// Connects to the session bus (via [`Builder::session`]).
    pub async fn session() -> Result<Self> {
        Builder::session()?.build().await
    }

    /// Connects to the system bus (via [`Builder::system`]).
    pub async fn system() -> Result<Self> {
        Builder::system()?.build().await
    }

    /// Returns a listener on this connection's activity event.
    ///
    /// The event is notified by `send` and `close`, and is also shared with the socket
    /// reader.
    pub fn monitor_activity(&self) -> EventListener {
        self.inner.activity_event.listen()
    }

    /// Credentials of the peer, queried through the (locked) write half of the socket.
    pub async fn peer_credentials(&self) -> io::Result<ConnectionCredentials> {
        self.inner
            .socket_write
            .lock()
            .await
            .peer_credentials()
            .await
    }

    /// Closes the write half of the socket after notifying activity listeners.
    pub async fn close(self) -> Result<()> {
        self.inner.activity_event.notify(usize::MAX);
        self.inner
            .socket_write
            .lock()
            .await
            .close()
            .await
            .map_err(Into::into)
    }

    /// Drops this handle and waits until the shared connection state is dropped.
    ///
    /// Completes only once every other strong handle is gone, including clones held by the
    /// user and by internal tasks. For example, the object server task holds an upgraded
    /// handle while a call is being dispatched.
    pub async fn graceful_shutdown(self) {
        let listener = self.inner.drop_event.listen();
        drop(self);
        listener.await;
    }

    /// Spawns the socket reader task that routes incoming messages to `msg_senders`.
    ///
    /// `already_read` (and, on Unix, `already_received_fds`) is data received before the
    /// reader existed, handed to `SocketReader::new`.
    ///
    /// # Panics
    ///
    /// Panics if called more than once.
    pub(crate) fn init_socket_reader(
        &self,
        socket_read: Box<dyn socket::ReadHalf>,
        already_read: Vec<u8>,
        #[cfg(unix)] already_received_fds: Vec<std::os::fd::OwnedFd>,
    ) {
        let inner = &self.inner;
        inner
            .socket_reader_task
            .set(
                SocketReader::new(
                    socket_read,
                    inner.msg_senders.clone(),
                    already_read,
                    #[cfg(unix)]
                    already_received_fds,
                    inner.activity_event.clone(),
                )
                .spawn(&inner.executor),
            )
            .expect("Attempted to set `socket_reader_task` twice");
    }

    /// Records the unique name.
    ///
    /// # Panics
    ///
    /// Panics if it was already set.
    fn set_unique_name_(&self, name: OwnedUniqueName) {
        self.inner
            .unique_name
            .set(name)
            .expect("unique name already set");
    }
}
1314
#[cfg(feature = "blocking-api")]
impl From<crate::blocking::Connection> for Connection {
    /// Converts a blocking connection into an async one via its `into_inner`.
    fn from(conn: crate::blocking::Connection) -> Self {
        conn.into_inner()
    }
}
1321
/// A weak handle to a [`Connection`] that does not keep the connection alive.
///
/// Used by internal background tasks so they don't prevent the connection from dropping.
#[derive(Debug, Clone)]
pub(crate) struct WeakConnection {
    inner: Weak<ConnectionInner>,
}
1327
1328impl WeakConnection {
1329 pub fn upgrade(&self) -> Option<Connection> {
1331 self.inner.upgrade().map(|inner| Connection { inner })
1332 }
1333}
1334
1335impl From<&Connection> for WeakConnection {
1336 fn from(conn: &Connection) -> Self {
1337 Self {
1338 inner: Arc::downgrade(&conn.inner),
1339 }
1340 }
1341}
1342
/// Status of a well-known name requested through a connection.
#[derive(Debug)]
enum NameStatus {
    /// The name is owned. Holds the `NameLost` watcher task, if one was spawned (only when
    /// `AllowReplacement` was requested). Never read; presumably held to keep the task alive.
    Owner(#[allow(unused)] Option<Task<()>>),
    /// The request is queued. Holds the `NameAcquired` watcher task. Never read; presumably
    /// held to keep the task alive.
    Queued(#[allow(unused)] Task<()>),
}
1350
1351static SERIAL_NUM_SEMAPHORE: Semaphore = Semaphore::new(1);
1352
1353async fn acquire_serial_num_semaphore() -> Option<SemaphorePermit<'static>> {
1358 if is_flatpak() {
1359 Some(SERIAL_NUM_SEMAPHORE.acquire().await)
1360 } else {
1361 None
1362 }
1363}
1364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fdo::DBusProxy;
    use ntest::timeout;
    use std::{pin::pin, time::Duration};
    use test_log::test;

    /// Connects to the session bus through the Windows autolaunch address.
    #[cfg(windows)]
    #[test]
    fn connect_autolaunch_session_bus() {
        let addr =
            crate::win32::autolaunch_bus_address().expect("Unable to get session bus address");

        crate::block_on(async { addr.connect().await }).expect("Unable to connect to session bus");
    }

    /// Connects to the session bus through launchd on macOS.
    #[cfg(target_os = "macos")]
    #[test]
    fn connect_launchd_session_bus() {
        use crate::address::{transport::Launchd, Address, Transport};
        crate::block_on(async {
            let addr = Address::from(Transport::Launchd(Launchd::new(
                "DBUS_LAUNCHD_SESSION_BUS_SOCKET",
            )));
            addr.connect().await
        })
        .expect("Unable to connect to session bus");
    }

    /// Dropping a connection must release the names it owned.
    #[test]
    #[timeout(15000)]
    fn disconnect_on_drop() {
        crate::utils::block_on(test_disconnect_on_drop());
    }

    async fn test_disconnect_on_drop() {
        #[derive(Default)]
        struct MyInterface {}

        #[crate::interface(name = "dev.peelz.FooBar.Baz")]
        impl MyInterface {
            fn do_thing(&self) {}
        }
        let name = "dev.peelz.foobar";
        let connection = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at("/dev/peelz/FooBar", MyInterface::default())
            .unwrap()
            .build()
            .await
            .unwrap();

        // Watch, from a second connection, for `name` changing to have no owner
        // (arg 0 = name, arg 2 = new owner, empty).
        let connection2 = Connection::session().await.unwrap();
        let dbus = DBusProxy::new(&connection2).await.unwrap();
        let mut stream = dbus
            .receive_name_owner_changed_with_args(&[(0, name), (2, "")])
            .await
            .unwrap();

        drop(connection);

        stream.next().await.unwrap();

        let name_has_owner = dbus.name_has_owner(name.try_into().unwrap()).await.unwrap();
        assert!(!name_has_owner);
    }

    /// `graceful_shutdown` must wait for other clones and for in-flight method calls.
    #[tokio::test(start_paused = true)]
    #[timeout(15000)]
    async fn test_graceful_shutdown() {
        // Part 1: a live clone blocks shutdown.
        let connection = Connection::session().await.unwrap();
        let clone = connection.clone();
        let mut shutdown = pin!(connection.graceful_shutdown());
        tokio::select! {
            // With `start_paused`, tokio auto-advances time when the runtime is idle, so this
            // sleep fires as soon as nothing else can make progress.
            _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
            _ = &mut shutdown => {
                panic!("Graceful shutdown unexpectedly completed");
            }
        }

        drop(clone);
        shutdown.await;

        // Part 2: a method call still being handled blocks shutdown.
        struct GracefulInterface {
            method_called: Event,
            wait_before_return: Option<EventListener>,
            announce_done: Event,
        }

        #[crate::interface(name = "dev.peelz.TestGracefulShutdown")]
        impl GracefulInterface {
            async fn do_thing(&mut self) {
                self.method_called.notify(1);
                if let Some(listener) = self.wait_before_return.take() {
                    listener.await;
                }
                self.announce_done.notify(1);
            }
        }

        let method_called = Event::new();
        let method_called_listener = method_called.listen();

        let trigger_return = Event::new();
        let wait_before_return = Some(trigger_return.listen());

        let announce_done = Event::new();
        let done_listener = announce_done.listen();

        let interface = GracefulInterface {
            method_called,
            wait_before_return,
            announce_done,
        };

        let name = "dev.peelz.TestGracefulShutdown";
        let obj = "/dev/peelz/TestGracefulShutdown";
        let connection = Builder::session()
            .unwrap()
            .name(name)
            .unwrap()
            .serve_at(obj, interface)
            .unwrap()
            .build()
            .await
            .unwrap();

        let client_conn = Connection::session().await.unwrap();
        tokio::spawn(async move {
            client_conn
                .call_method(Some(name), obj, Some(name), "DoThing", &())
                .await
                .unwrap();
        });

        // The method is now running and parked until `trigger_return` fires.
        method_called_listener.await;

        let mut shutdown = pin!(connection.graceful_shutdown());
        tokio::select! {
            _ = tokio::time::sleep(Duration::from_secs(u64::MAX)) => {},
            _ = &mut shutdown => {
                panic!("Graceful shutdown unexpectedly completed");
            }
        }

        // Let the method finish; shutdown can then complete.
        trigger_return.notify(1);
        shutdown.await;

        done_listener.await;
    }
}
1533
1534#[cfg(feature = "p2p")]
1535#[cfg(test)]
1536mod p2p_tests {
1537 use event_listener::Event;
1538 use futures_util::TryStreamExt;
1539 use ntest::timeout;
1540 use test_log::test;
1541 use zvariant::{Endian, NATIVE_ENDIAN};
1542
1543 use super::{socket, Builder, Connection};
1544 use crate::{conn::AuthMechanism, Guid, Message, MessageStream, Result};
1545
1546 async fn test_p2p(
1548 server1: Connection,
1549 client1: Connection,
1550 server2: Connection,
1551 client2: Connection,
1552 ) -> Result<()> {
1553 let forward1 = {
1554 let stream = MessageStream::from(server1.clone());
1555 let sink = client2.clone();
1556
1557 stream.try_for_each(move |msg| {
1558 let sink = sink.clone();
1559 async move { sink.send(&msg).await }
1560 })
1561 };
1562 let forward2 = {
1563 let stream = MessageStream::from(client2.clone());
1564 let sink = server1.clone();
1565
1566 stream.try_for_each(move |msg| {
1567 let sink = sink.clone();
1568 async move { sink.send(&msg).await }
1569 })
1570 };
1571 let _forward_task = client1.executor().spawn(
1572 async move { futures_util::try_join!(forward1, forward2) },
1573 "forward_task",
1574 );
1575
1576 let server_ready = Event::new();
1577 let server_ready_listener = server_ready.listen();
1578 let client_done = Event::new();
1579 let client_done_listener = client_done.listen();
1580
1581 let server_future = async move {
1582 let mut stream = MessageStream::from(&server2);
1583 server_ready.notify(1);
1584 let method = loop {
1585 let m = stream.try_next().await?.unwrap();
1586 if m.to_string() == "Method call Test" {
1587 assert_eq!(m.body().deserialize::<u64>().unwrap(), 64);
1588 break m;
1589 }
1590 };
1591
1592 server2
1594 .emit_signal(None::<()>, "/", "org.zbus.p2p", "ASignalForYou", &())
1595 .await?;
1596 server2.reply(&method.header(), &("yay")).await?;
1597 client_done_listener.await;
1598
1599 Ok(())
1600 };
1601
1602 let client_future = async move {
1603 let mut stream = MessageStream::from(&client1);
1604 server_ready_listener.await;
1605 let endian = match NATIVE_ENDIAN {
1609 Endian::Little => Endian::Big,
1610 Endian::Big => Endian::Little,
1611 };
1612 let method = Message::method_call("/", "Test")?
1613 .interface("org.zbus.p2p")?
1614 .endian(endian)
1615 .build(&64u64)?;
1616 client1.send(&method).await?;
1617 let m = stream.try_next().await?.unwrap();
1619 client_done.notify(1);
1620 assert_eq!(m.to_string(), "Signal ASignalForYou");
1621 let reply = stream.try_next().await?.unwrap();
1622 assert_eq!(reply.to_string(), "Method return");
1623 assert_eq!(Endian::from(reply.primary_header().endian_sig()), endian);
1625 reply.body().deserialize::<String>()
1626 };
1627
1628 let (val, _) = futures_util::try_join!(client_future, server_future,)?;
1629 assert_eq!(val, "yay");
1630
1631 Ok(())
1632 }
1633
1634 #[test]
1635 #[timeout(15000)]
1636 fn tcp_p2p() {
1637 crate::utils::block_on(test_tcp_p2p()).unwrap();
1638 }
1639
1640 async fn test_tcp_p2p() -> Result<()> {
1641 let (server1, client1) = tcp_p2p_pipe().await?;
1642 let (server2, client2) = tcp_p2p_pipe().await?;
1643
1644 test_p2p(server1, client1, server2, client2).await
1645 }
1646
1647 async fn tcp_p2p_pipe() -> Result<(Connection, Connection)> {
1648 let guid = Guid::generate();
1649
1650 #[cfg(not(feature = "tokio"))]
1651 let (server_conn_builder, client_conn_builder) = {
1652 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
1653 let addr = listener.local_addr().unwrap();
1654 let p1 = std::net::TcpStream::connect(addr).unwrap();
1655 let p0 = listener.incoming().next().unwrap().unwrap();
1656
1657 (
1658 Builder::tcp_stream(p0)
1659 .server(guid)
1660 .unwrap()
1661 .p2p()
1662 .auth_mechanism(AuthMechanism::Anonymous),
1663 Builder::tcp_stream(p1).p2p(),
1664 )
1665 };
1666
1667 #[cfg(feature = "tokio")]
1668 let (server_conn_builder, client_conn_builder) = {
1669 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1670 let addr = listener.local_addr().unwrap();
1671 let p1 = tokio::net::TcpStream::connect(addr).await.unwrap();
1672 let p0 = listener.accept().await.unwrap().0;
1673
1674 (
1675 Builder::tcp_stream(p0)
1676 .server(guid)
1677 .unwrap()
1678 .p2p()
1679 .auth_mechanism(AuthMechanism::Anonymous),
1680 Builder::tcp_stream(p1).p2p(),
1681 )
1682 };
1683
1684 futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
1685 }
1686
1687 #[cfg(unix)]
1688 #[test]
1689 #[timeout(15000)]
1690 fn unix_p2p() {
1691 crate::utils::block_on(test_unix_p2p()).unwrap();
1692 }
1693
1694 #[cfg(unix)]
1695 async fn test_unix_p2p() -> Result<()> {
1696 let (server1, client1) = unix_p2p_pipe().await?;
1697 let (server2, client2) = unix_p2p_pipe().await?;
1698
1699 test_p2p(server1, client1, server2, client2).await
1700 }
1701
1702 #[cfg(unix)]
1703 async fn unix_p2p_pipe() -> Result<(Connection, Connection)> {
1704 #[cfg(not(feature = "tokio"))]
1705 use std::os::unix::net::UnixStream;
1706 #[cfg(feature = "tokio")]
1707 use tokio::net::UnixStream;
1708 #[cfg(all(windows, not(feature = "tokio")))]
1709 use uds_windows::UnixStream;
1710
1711 let guid = Guid::generate();
1712
1713 let (p0, p1) = UnixStream::pair().unwrap();
1714
1715 futures_util::try_join!(
1716 Builder::unix_stream(p1).p2p().build(),
1717 Builder::unix_stream(p0).server(guid).unwrap().p2p().build(),
1718 )
1719 }
1720
1721 #[cfg(any(
1722 all(feature = "vsock", not(feature = "tokio")),
1723 feature = "tokio-vsock"
1724 ))]
1725 #[test]
1726 #[timeout(15000)]
1727 fn vsock_connect() {
1728 let _ = crate::utils::block_on(test_vsock_connect()).unwrap();
1729 }
1730
    /// Connects a client to a vsock listener through a `vsock:cid=…,port=…,guid=…` address
    /// string and returns the `(server, client)` connections.
    ///
    /// The server side accepts one incoming stream and builds a p2p server connection with
    /// anonymous authentication. The client is built from the address via
    /// `Builder::address`. Both sides are driven concurrently with `try_join!`.
    #[cfg(any(
        all(feature = "vsock", not(feature = "tokio")),
        feature = "tokio-vsock"
    ))]
    async fn test_vsock_connect() -> Result<(Connection, Connection)> {
        #[cfg(feature = "tokio-vsock")]
        use futures_util::StreamExt;

        let guid = Guid::generate();

        // The real port is read back from `local_addr()` below, so `u32::MAX` presumably asks
        // the kernel to pick one (VMADDR_PORT_ANY) — confirm. In the tokio branch, the literal
        // CID `1` is presumably the same local CID as `vsock::VMADDR_CID_LOCAL` — confirm.
        #[cfg(all(feature = "vsock", not(feature = "tokio")))]
        let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX)?;
        #[cfg(feature = "tokio-vsock")]
        let listener = tokio_vsock::VsockListener::bind(tokio_vsock::VsockAddr::new(1, u32::MAX))?;

        let addr = listener.local_addr()?;
        let addr = format!("vsock:cid={},port={},guid={guid}", addr.cid(), addr.port());

        let server = async {
            // Without tokio, accepting is a blocking call, so it runs on a blocking task.
            #[cfg(all(feature = "vsock", not(feature = "tokio")))]
            let server = crate::Task::spawn_blocking(move || listener.incoming().next(), "").await;
            #[cfg(feature = "tokio-vsock")]
            let server = listener.incoming().next().await;
            // `server` is `Option<io::Result<_>>`: unwrap the first incoming item, then `?` the I/O result.
            Builder::vsock_stream(server.unwrap()?)
                .server(guid)?
                .p2p()
                .auth_mechanism(AuthMechanism::Anonymous)
                .build()
                .await
        };

        let client = crate::connection::Builder::address(addr.as_str())?
            .p2p()
            .build();

        futures_util::try_join!(server, client)
    }
1768
1769 #[cfg(any(
1770 all(feature = "vsock", not(feature = "tokio")),
1771 feature = "tokio-vsock"
1772 ))]
1773 #[test]
1774 #[timeout(15000)]
1775 fn vsock_p2p() {
1776 crate::utils::block_on(test_vsock_p2p()).unwrap();
1777 }
1778
1779 #[cfg(any(
1780 all(feature = "vsock", not(feature = "tokio")),
1781 feature = "tokio-vsock"
1782 ))]
1783 async fn test_vsock_p2p() -> Result<()> {
1784 let (server1, client1) = vsock_p2p_pipe().await?;
1785 let (server2, client2) = vsock_p2p_pipe().await?;
1786
1787 test_p2p(server1, client1, server2, client2).await
1788 }
1789
1790 #[cfg(all(feature = "vsock", not(feature = "tokio")))]
1791 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1792 let guid = Guid::generate();
1793
1794 let listener =
1795 vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_LOCAL, u32::MAX).unwrap();
1796 let addr = listener.local_addr().unwrap();
1797 let client = vsock::VsockStream::connect(&addr).unwrap();
1798 let server = listener.incoming().next().unwrap().unwrap();
1799
1800 futures_util::try_join!(
1801 Builder::vsock_stream(server)
1802 .server(guid)
1803 .unwrap()
1804 .p2p()
1805 .auth_mechanism(AuthMechanism::Anonymous)
1806 .build(),
1807 Builder::vsock_stream(client).p2p().build(),
1808 )
1809 }
1810
1811 #[cfg(feature = "tokio-vsock")]
1812 async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
1813 use futures_util::StreamExt;
1814 use tokio_vsock::VsockAddr;
1815
1816 let guid = Guid::generate();
1817
1818 let listener = tokio_vsock::VsockListener::bind(VsockAddr::new(1, u32::MAX)).unwrap();
1819 let addr = listener.local_addr().unwrap();
1820 let client = tokio_vsock::VsockStream::connect(addr).await.unwrap();
1821 let server = listener.incoming().next().await.unwrap().unwrap();
1822
1823 futures_util::try_join!(
1824 Builder::vsock_stream(server)
1825 .server(guid)
1826 .unwrap()
1827 .p2p()
1828 .auth_mechanism(AuthMechanism::Anonymous)
1829 .build(),
1830 Builder::vsock_stream(client).p2p().build(),
1831 )
1832 }
1833
1834 #[test]
1835 #[timeout(15000)]
1836 fn channel_pair() {
1837 crate::utils::block_on(test_channel_pair()).unwrap();
1838 }
1839
1840 async fn test_channel_pair() -> Result<()> {
1841 let (server1, client1) = create_channel_pair().await;
1842 let (server2, client2) = create_channel_pair().await;
1843
1844 test_p2p(server1, client1, server2, client2).await
1845 }
1846
1847 async fn create_channel_pair() -> (Connection, Connection) {
1848 let (a, b) = socket::Channel::pair();
1849
1850 let guid = crate::Guid::generate();
1851 let conn1 = Builder::authenticated_socket(a, guid.clone())
1852 .unwrap()
1853 .p2p()
1854 .build()
1855 .await
1856 .unwrap();
1857 let conn2 = Builder::authenticated_socket(b, guid)
1858 .unwrap()
1859 .p2p()
1860 .build()
1861 .await
1862 .unwrap();
1863
1864 (conn1, conn2)
1865 }
1866}