1use async_broadcast::{broadcast, InactiveReceiver, Receiver, Sender as Broadcaster};
2use enumflags2::BitFlags;
3use event_listener::{Event, EventListener};
4use once_cell::sync::OnceCell;
5use ordered_stream::{OrderedFuture, OrderedStream, PollResult};
6use static_assertions::assert_impl_all;
7use std::{
8 collections::HashMap,
9 convert::TryInto,
10 io::{self, ErrorKind},
11 ops::Deref,
12 pin::Pin,
13 sync::{
14 self,
15 atomic::{AtomicU32, Ordering::SeqCst},
16 Arc, Weak,
17 },
18 task::{Context, Poll},
19};
20use tracing::{debug, info_span, instrument, trace, trace_span, warn, Instrument};
21use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
22use zvariant::ObjectPath;
23
24use futures_core::{ready, Future};
25use futures_sink::Sink;
26use futures_util::{sink::SinkExt, StreamExt};
27
28use crate::{
29 async_lock::Mutex,
30 blocking,
31 fdo::{self, ConnectionCredentials, RequestNameFlags, RequestNameReply},
32 raw::{Connection as RawConnection, Socket},
33 socket_reader::SocketReader,
34 Authenticated, CacheProperties, ConnectionBuilder, DBusError, Error, Executor, Guid, MatchRule,
35 Message, MessageBuilder, MessageFlags, MessageStream, MessageType, ObjectServer,
36 OwnedMatchRule, Result, Task,
37};
38
/// Default capacity of a message broadcast channel: the catch-all channel created in
/// `Connection::new` and each per-rule channel created by `Connection::add_match` when no
/// explicit `max_queued` is given.
const DEFAULT_MAX_QUEUED: usize = 64;
/// Capacity of the shared channel carrying method returns and errors to pending method calls.
const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8;
41
/// Shared state behind every clone of a [`Connection`].
///
/// `Connection` holds this in an `Arc`; `WeakConnection` holds a `Weak` to it so background
/// tasks do not keep the connection alive.
// NOTE(review): field order is also drop order; keep it stable when editing.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    // GUID reported by authentication (`auth.server_guid` in `Connection::new`).
    server_guid: Guid,
    // Whether the authenticated connection reported FD-passing capability; `start_send`
    // rejects messages carrying FDs when this is `false`.
    #[cfg(unix)]
    cap_unix_fd: bool,
    // Value returned by `Connection::is_bus`.
    bus_conn: bool,
    // Set at most once, by `set_unique_name` or `hello_bus` (both panic on a second set).
    unique_name: OnceCell<OwnedUniqueName>,
    // Well-known names requested through this connection and their current status.
    registered_names: Mutex<HashMap<WellKnownName<'static>, NameStatus>>,

    // The underlying connection; a clone of this `Arc` is handed to the socket reader.
    raw_conn: Arc<sync::Mutex<RawConnection<Box<dyn Socket>>>>,

    // Source of message serials for `next_serial` (initialised to 1 in `Connection::new`).
    serial: AtomicU32,

    // Executor on which all of this connection's internal tasks are spawned.
    executor: Executor<'static>,

    // Never read; presumably stored only to keep the socket reader task alive.
    #[allow(unused)]
    socket_reader_task: OnceCell<Task<()>>,

    // Inactive receiver of the catch-all channel (keyed by `None` in `msg_senders`).
    pub(crate) msg_receiver: InactiveReceiver<Result<Arc<Message>>>,
    // Inactive receiver of the channel fed with method-return and error messages.
    pub(crate) method_return_receiver: InactiveReceiver<Result<Arc<Message>>>,
    // Broadcast senders keyed by match rule (`None` = every message); shared with the
    // socket reader.
    msg_senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,

    // Per-rule subscription reference counts and inactive receivers.
    subscriptions: Mutex<Subscriptions>,

    // Lazily created object server and its dispatch task.
    object_server: OnceCell<blocking::ObjectServer>,
    object_server_dispatch_task: OnceCell<Task<()>>,
}
73
74type Subscriptions = HashMap<OwnedMatchRule, (u64, InactiveReceiver<Result<Arc<Message>>>)>;
75
76pub(crate) type MsgBroadcaster = Broadcaster<Result<Arc<Message>>>;
77
78#[derive(Clone, Debug)]
202#[must_use = "Dropping a `Connection` will close the underlying socket."]
203pub struct Connection {
204 pub(crate) inner: Arc<ConnectionInner>,
205}
206
207assert_impl_all!(Connection: Send, Sync, Unpin);
208
209#[derive(Debug)]
215pub(crate) struct PendingMethodCall {
216 stream: Option<MessageStream>,
217 serial: u32,
218}
219
220impl Future for PendingMethodCall {
221 type Output = Result<Arc<Message>>;
222
223 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
224 self.poll_before(cx, None).map(|ret| {
225 ret.map(|(_, r)| r).unwrap_or_else(|| {
226 Err(crate::Error::InputOutput(
227 io::Error::new(ErrorKind::BrokenPipe, "socket closed").into(),
228 ))
229 })
230 })
231 }
232}
233
234impl OrderedFuture for PendingMethodCall {
235 type Output = Result<Arc<Message>>;
236 type Ordering = zbus::MessageSequence;
237
238 fn poll_before(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 before: Option<&Self::Ordering>,
242 ) -> Poll<Option<(Self::Ordering, Self::Output)>> {
243 let this = self.get_mut();
244 if let Some(stream) = &mut this.stream {
245 loop {
246 match Pin::new(&mut *stream).poll_next_before(cx, before) {
247 Poll::Ready(PollResult::Item {
248 data: Ok(msg),
249 ordering,
250 }) => {
251 if msg.reply_serial() != Some(this.serial) {
252 continue;
253 }
254 let res = match msg.message_type() {
255 MessageType::Error => Err(msg.into()),
256 MessageType::MethodReturn => Ok(msg),
257 _ => continue,
258 };
259 this.stream = None;
260 return Poll::Ready(Some((ordering, res)));
261 }
262 Poll::Ready(PollResult::Item {
263 data: Err(e),
264 ordering,
265 }) => {
266 return Poll::Ready(Some((ordering, Err(e))));
267 }
268
269 Poll::Ready(PollResult::NoneBefore) => {
270 return Poll::Ready(None);
271 }
272 Poll::Ready(PollResult::Terminated) => {
273 return Poll::Ready(None);
274 }
275 Poll::Pending => return Poll::Pending,
276 }
277 }
278 }
279 Poll::Ready(None)
280 }
281}
282
283impl Connection {
284 pub async fn send_message(&self, mut msg: Message) -> Result<u32> {
291 let serial = self.assign_serial_num(&mut msg)?;
292
293 trace!("Sending message: {:?}", msg);
294 (&mut &*self).send(msg).await?;
295 trace!("Sent message with serial: {}", serial);
296
297 Ok(serial)
298 }
299
300 pub async fn call_method<'d, 'p, 'i, 'm, D, P, I, M, B>(
307 &self,
308 destination: Option<D>,
309 path: P,
310 interface: Option<I>,
311 method_name: M,
312 body: &B,
313 ) -> Result<Arc<Message>>
314 where
315 D: TryInto<BusName<'d>>,
316 P: TryInto<ObjectPath<'p>>,
317 I: TryInto<InterfaceName<'i>>,
318 M: TryInto<MemberName<'m>>,
319 D::Error: Into<Error>,
320 P::Error: Into<Error>,
321 I::Error: Into<Error>,
322 M::Error: Into<Error>,
323 B: serde::ser::Serialize + zvariant::DynamicType,
324 {
325 self.call_method_raw(
326 destination,
327 path,
328 interface,
329 method_name,
330 BitFlags::empty(),
331 body,
332 )
333 .await?
334 .expect("no reply")
335 .await
336 }
337
338 pub(crate) async fn call_method_raw<'d, 'p, 'i, 'm, D, P, I, M, B>(
349 &self,
350 destination: Option<D>,
351 path: P,
352 interface: Option<I>,
353 method_name: M,
354 flags: BitFlags<MessageFlags>,
355 body: &B,
356 ) -> Result<Option<PendingMethodCall>>
357 where
358 D: TryInto<BusName<'d>>,
359 P: TryInto<ObjectPath<'p>>,
360 I: TryInto<InterfaceName<'i>>,
361 M: TryInto<MemberName<'m>>,
362 D::Error: Into<Error>,
363 P::Error: Into<Error>,
364 I::Error: Into<Error>,
365 M::Error: Into<Error>,
366 B: serde::ser::Serialize + zvariant::DynamicType,
367 {
368 let mut builder = MessageBuilder::method_call(path, method_name)?;
369 if let Some(sender) = self.unique_name() {
370 builder = builder.sender(sender)?
371 }
372 if let Some(destination) = destination {
373 builder = builder.destination(destination)?
374 }
375 if let Some(interface) = interface {
376 builder = builder.interface(interface)?
377 }
378 for flag in flags {
379 builder = builder.with_flags(flag)?;
380 }
381 let msg = builder.build(body)?;
382
383 let msg_receiver = self.inner.method_return_receiver.activate_cloned();
384 let stream = Some(MessageStream::for_subscription_channel(
385 msg_receiver,
386 None,
388 self,
389 ));
390 let serial = self.send_message(msg).await?;
391 if flags.contains(MessageFlags::NoReplyExpected) {
392 Ok(None)
393 } else {
394 Ok(Some(PendingMethodCall { stream, serial }))
395 }
396 }
397
398 pub async fn emit_signal<'d, 'p, 'i, 'm, D, P, I, M, B>(
402 &self,
403 destination: Option<D>,
404 path: P,
405 interface: I,
406 signal_name: M,
407 body: &B,
408 ) -> Result<()>
409 where
410 D: TryInto<BusName<'d>>,
411 P: TryInto<ObjectPath<'p>>,
412 I: TryInto<InterfaceName<'i>>,
413 M: TryInto<MemberName<'m>>,
414 D::Error: Into<Error>,
415 P::Error: Into<Error>,
416 I::Error: Into<Error>,
417 M::Error: Into<Error>,
418 B: serde::ser::Serialize + zvariant::DynamicType,
419 {
420 let m = Message::signal(
421 self.unique_name(),
422 destination,
423 path,
424 interface,
425 signal_name,
426 body,
427 )?;
428
429 self.send_message(m).await.map(|_| ())
430 }
431
432 pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<u32>
439 where
440 B: serde::ser::Serialize + zvariant::DynamicType,
441 {
442 let m = Message::method_reply(self.unique_name(), call, body)?;
443 self.send_message(m).await
444 }
445
446 pub async fn reply_error<'e, E, B>(
453 &self,
454 call: &Message,
455 error_name: E,
456 body: &B,
457 ) -> Result<u32>
458 where
459 B: serde::ser::Serialize + zvariant::DynamicType,
460 E: TryInto<ErrorName<'e>>,
461 E::Error: Into<Error>,
462 {
463 let m = Message::method_error(self.unique_name(), call, error_name, body)?;
464 self.send_message(m).await
465 }
466
467 pub async fn reply_dbus_error(
474 &self,
475 call: &zbus::MessageHeader<'_>,
476 err: impl DBusError,
477 ) -> Result<u32> {
478 let m = err.create_reply(call);
479 self.send_message(m?).await
480 }
481
482 pub async fn request_name<'w, W>(&self, well_known_name: W) -> Result<()>
512 where
513 W: TryInto<WellKnownName<'w>>,
514 W::Error: Into<Error>,
515 {
516 self.request_name_with_flags(
517 well_known_name,
518 RequestNameFlags::ReplaceExisting | RequestNameFlags::DoNotQueue,
519 )
520 .await
521 .map(|_| ())
522 }
523
524 pub async fn request_name_with_flags<'w, W>(
598 &self,
599 well_known_name: W,
600 flags: BitFlags<RequestNameFlags>,
601 ) -> Result<RequestNameReply>
602 where
603 W: TryInto<WellKnownName<'w>>,
604 W::Error: Into<Error>,
605 {
606 let well_known_name = well_known_name.try_into().map_err(Into::into)?;
607 let mut names = self.inner.registered_names.lock().await;
610
611 match names.get(&well_known_name) {
612 Some(NameStatus::Owner(_)) => return Ok(RequestNameReply::AlreadyOwner),
613 Some(NameStatus::Queued(_)) => return Ok(RequestNameReply::InQueue),
614 None => (),
615 }
616
617 if !self.is_bus() {
618 names.insert(well_known_name.to_owned(), NameStatus::Owner(None));
619
620 return Ok(RequestNameReply::PrimaryOwner);
621 }
622
623 let dbus_proxy = fdo::DBusProxy::builder(self)
624 .cache_properties(CacheProperties::No)
625 .build()
626 .await?;
627 let mut acquired_stream = dbus_proxy.receive_name_acquired().await?;
628 let mut lost_stream = dbus_proxy.receive_name_lost().await?;
629 let reply = dbus_proxy
630 .request_name(well_known_name.clone(), flags)
631 .await?;
632 let lost_task_name = format!("monitor name {well_known_name} lost");
633 let name_lost_fut = if flags.contains(RequestNameFlags::AllowReplacement) {
634 let weak_conn = WeakConnection::from(self);
635 let well_known_name = well_known_name.to_owned();
636 Some(
637 async move {
638 loop {
639 let signal = lost_stream.next().await;
640 let inner = match weak_conn.upgrade() {
641 Some(conn) => conn.inner.clone(),
642 None => break,
643 };
644
645 match signal {
646 Some(signal) => match signal.args() {
647 Ok(args) if args.name == well_known_name => {
648 tracing::info!(
649 "Connection `{}` lost name `{}`",
650 inner.unique_name.get().unwrap(),
653 well_known_name
654 );
655 inner.registered_names.lock().await.remove(&well_known_name);
656
657 break;
658 }
659 Ok(_) => (),
660 Err(e) => warn!("Failed to parse `NameLost` signal: {}", e),
661 },
662 None => {
663 trace!("`NameLost` signal stream closed");
664 break;
673 }
674 }
675 }
676 }
677 .instrument(info_span!("{}", lost_task_name)),
678 )
679 } else {
680 None
681 };
682 let status = match reply {
683 RequestNameReply::InQueue => {
684 let weak_conn = WeakConnection::from(self);
685 let well_known_name = well_known_name.to_owned();
686 let task_name = format!("monitor name {well_known_name} acquired");
687 let task = self.executor().spawn(
688 async move {
689 loop {
690 let signal = acquired_stream.next().await;
691 let inner = match weak_conn.upgrade() {
692 Some(conn) => conn.inner.clone(),
693 None => break,
694 };
695 match signal {
696 Some(signal) => match signal.args() {
697 Ok(args) if args.name == well_known_name => {
698 let mut names = inner.registered_names.lock().await;
699 if let Some(status) = names.get_mut(&well_known_name) {
700 let task = name_lost_fut.map(|fut| {
701 inner.executor.spawn(fut, &lost_task_name)
702 });
703 *status = NameStatus::Owner(task);
704
705 break;
706 }
707 }
709 Ok(_) => (),
710 Err(e) => warn!("Failed to parse `NameAcquired` signal: {}", e),
711 },
712 None => {
713 trace!("`NameAcquired` signal stream closed");
714 break;
717 }
718 }
719 }
720 }
721 .instrument(info_span!("{}", task_name)),
722 &task_name,
723 );
724
725 NameStatus::Queued(task)
726 }
727 RequestNameReply::PrimaryOwner | RequestNameReply::AlreadyOwner => {
728 let task = name_lost_fut.map(|fut| self.executor().spawn(fut, &lost_task_name));
729
730 NameStatus::Owner(task)
731 }
732 RequestNameReply::Exists => return Err(Error::NameTaken),
733 };
734
735 names.insert(well_known_name.to_owned(), status);
736
737 Ok(reply)
738 }
739
740 pub async fn release_name<'w, W>(&self, well_known_name: W) -> Result<bool>
749 where
750 W: TryInto<WellKnownName<'w>>,
751 W::Error: Into<Error>,
752 {
753 let well_known_name: WellKnownName<'w> = well_known_name.try_into().map_err(Into::into)?;
754 let mut names = self.inner.registered_names.lock().await;
755 if names.remove(&well_known_name.to_owned()).is_none() {
757 return Ok(false);
758 };
759
760 if !self.is_bus() {
761 return Ok(true);
762 }
763
764 fdo::DBusProxy::builder(self)
765 .cache_properties(CacheProperties::No)
766 .build()
767 .await?
768 .release_name(well_known_name)
769 .await
770 .map(|_| true)
771 .map_err(Into::into)
772 }
773
774 pub fn is_bus(&self) -> bool {
778 self.inner.bus_conn
779 }
780
781 pub fn assign_serial_num(&self, msg: &mut Message) -> Result<u32> {
785 let mut serial = 0;
786 msg.modify_primary_header(|primary| {
787 serial = *primary.serial_num_or_init(|| self.next_serial());
788 Ok(())
789 })?;
790
791 Ok(serial)
792 }
793
794 pub fn unique_name(&self) -> Option<&OwnedUniqueName> {
799 self.inner.unique_name.get()
800 }
801
802 pub fn set_unique_name<U>(&self, unique_name: U) -> Result<()>
810 where
811 U: TryInto<OwnedUniqueName>,
812 U::Error: Into<Error>,
813 {
814 let name = unique_name.try_into().map_err(Into::into)?;
815 self.inner
816 .unique_name
817 .set(name)
818 .expect("unique name already set");
819
820 Ok(())
821 }
822
823 pub fn max_queued(&self) -> usize {
825 self.inner.msg_receiver.capacity()
826 }
827
828 pub fn set_max_queued(&mut self, max: usize) {
830 self.inner.msg_receiver.clone().set_capacity(max);
831 }
832
833 pub fn server_guid(&self) -> &str {
835 self.inner.server_guid.as_str()
836 }
837
838 pub fn executor(&self) -> &Executor<'static> {
895 &self.inner.executor
896 }
897
898 pub fn object_server(&self) -> impl Deref<Target = ObjectServer> + '_ {
906 struct Wrapper<'a>(&'a blocking::ObjectServer);
909 impl<'a> Deref for Wrapper<'a> {
910 type Target = ObjectServer;
911
912 fn deref(&self) -> &Self::Target {
913 self.0.inner()
914 }
915 }
916
917 Wrapper(self.sync_object_server(true, None))
918 }
919
920 pub(crate) fn sync_object_server(
921 &self,
922 start: bool,
923 started_event: Option<Event>,
924 ) -> &blocking::ObjectServer {
925 self.inner
926 .object_server
927 .get_or_init(move || self.setup_object_server(start, started_event))
928 }
929
930 fn setup_object_server(
931 &self,
932 start: bool,
933 started_event: Option<Event>,
934 ) -> blocking::ObjectServer {
935 if start {
936 self.start_object_server(started_event);
937 }
938
939 blocking::ObjectServer::new(self)
940 }
941
942 #[instrument(skip(self))]
943 pub(crate) fn start_object_server(&self, started_event: Option<Event>) {
944 self.inner.object_server_dispatch_task.get_or_init(|| {
945 trace!("starting ObjectServer task");
946 let weak_conn = WeakConnection::from(self);
947
948 let obj_server_task_name = "ObjectServer task";
949 self.inner.executor.spawn(
950 async move {
951 let mut stream = match weak_conn.upgrade() {
952 Some(conn) => {
953 let mut builder = MatchRule::builder().msg_type(MessageType::MethodCall);
954 if let Some(unique_name) = conn.unique_name() {
955 builder = builder.destination(&**unique_name).expect("unique name");
956 }
957 let rule = builder.build();
958 match conn.add_match(rule.into(), None).await {
959 Ok(stream) => stream,
960 Err(e) => {
961 debug!("Failed to create message stream: {}", e);
963
964 return;
965 }
966 }
967 }
968 None => {
969 trace!("Connection is gone, stopping associated object server task");
970
971 return;
972 }
973 };
974 if let Some(started_event) = started_event {
975 started_event.notify(1);
976 }
977
978 trace!("waiting for incoming method call messages..");
979 while let Some(msg) = stream.next().await.and_then(|m| {
980 if let Err(e) = &m {
981 debug!("Error while reading from object server stream: {:?}", e);
982 }
983 m.ok()
984 }) {
985 if let Some(conn) = weak_conn.upgrade() {
986 let hdr = match msg.header() {
987 Ok(hdr) => hdr,
988 Err(e) => {
989 warn!("Failed to parse header: {}", e);
990
991 continue;
992 }
993 };
994 match hdr.destination() {
995 Ok(Some(BusName::Unique(_))) | Ok(None) => (),
997 Ok(Some(BusName::WellKnown(dest))) => {
998 let names = conn.inner.registered_names.lock().await;
999 if !names.is_empty() && !names.contains_key(dest) {
1002 trace!("Got a method call for a different destination: {}", dest);
1003
1004 continue;
1005 }
1006 }
1007 Err(e) => {
1008 warn!("Failed to parse destination: {}", e);
1009
1010 continue;
1011 }
1012 }
1013 let member = match msg.member() {
1014 Some(member) => member,
1015 None => {
1016 warn!("Got a method call with no `MEMBER` field: {}", msg);
1017
1018 continue;
1019 }
1020 };
1021 trace!("Got `{}`. Will spawn a task for dispatch..", msg);
1022 let executor = conn.inner.executor.clone();
1023 let task_name = format!("`{member}` method dispatcher");
1024 executor
1025 .spawn(
1026 async move {
1027 trace!("spawned a task to dispatch `{}`.", msg);
1028 let server = conn.object_server();
1029 if let Err(e) = server.dispatch_message(&msg).await {
1030 debug!(
1031 "Error dispatching message. Message: {:?}, error: {:?}",
1032 msg, e
1033 );
1034 }
1035 }
1036 .instrument(trace_span!("{}", task_name)),
1037 &task_name,
1038 )
1039 .detach();
1040 } else {
1041 trace!("Connection is gone, stopping associated object server task");
1043 break;
1044 }
1045 }
1046 }
1047 .instrument(info_span!("{}", obj_server_task_name)),
1048 obj_server_task_name,
1049 )
1050 });
1051 }
1052
1053 pub(crate) async fn add_match(
1054 &self,
1055 rule: OwnedMatchRule,
1056 max_queued: Option<usize>,
1057 ) -> Result<Receiver<Result<Arc<Message>>>> {
1058 use std::collections::hash_map::Entry;
1059
1060 if self.inner.msg_senders.lock().await.is_empty() {
1061 return Err(Error::InputOutput(Arc::new(io::Error::new(
1063 io::ErrorKind::BrokenPipe,
1064 "Socket reader task has errored out",
1065 ))));
1066 }
1067
1068 let mut subscriptions = self.inner.subscriptions.lock().await;
1069 let msg_type = rule.msg_type().unwrap_or(MessageType::Signal);
1070 match subscriptions.entry(rule.clone()) {
1071 Entry::Vacant(e) => {
1072 let max_queued = max_queued.unwrap_or(DEFAULT_MAX_QUEUED);
1073 let (sender, mut receiver) = broadcast(max_queued);
1074 receiver.set_await_active(false);
1075 if self.is_bus() && msg_type == MessageType::Signal {
1076 fdo::DBusProxy::builder(self)
1077 .cache_properties(CacheProperties::No)
1078 .build()
1079 .await?
1080 .add_match_rule(e.key().inner().clone())
1081 .await?;
1082 }
1083 e.insert((1, receiver.clone().deactivate()));
1084 self.inner
1085 .msg_senders
1086 .lock()
1087 .await
1088 .insert(Some(rule), sender);
1089
1090 Ok(receiver)
1091 }
1092 Entry::Occupied(mut e) => {
1093 let (num_subscriptions, receiver) = e.get_mut();
1094 *num_subscriptions += 1;
1095 if let Some(max_queued) = max_queued {
1096 if max_queued > receiver.capacity() {
1097 receiver.set_capacity(max_queued);
1098 }
1099 }
1100
1101 Ok(receiver.activate_cloned())
1102 }
1103 }
1104 }
1105
1106 pub(crate) async fn remove_match(&self, rule: OwnedMatchRule) -> Result<bool> {
1107 use std::collections::hash_map::Entry;
1108 let mut subscriptions = self.inner.subscriptions.lock().await;
1109 let msg_type = rule.msg_type().unwrap_or(MessageType::Signal);
1112 match subscriptions.entry(rule) {
1113 Entry::Vacant(_) => Ok(false),
1114 Entry::Occupied(mut e) => {
1115 let rule = e.key().inner().clone();
1116 e.get_mut().0 -= 1;
1117 if e.get().0 == 0 {
1118 if self.is_bus() && msg_type == MessageType::Signal {
1119 fdo::DBusProxy::builder(self)
1120 .cache_properties(CacheProperties::No)
1121 .build()
1122 .await?
1123 .remove_match_rule(rule.clone())
1124 .await?;
1125 }
1126 e.remove();
1127 self.inner
1128 .msg_senders
1129 .lock()
1130 .await
1131 .remove(&Some(rule.into()));
1132 }
1133 Ok(true)
1134 }
1135 }
1136 }
1137
1138 pub(crate) fn queue_remove_match(&self, rule: OwnedMatchRule) {
1139 let conn = self.clone();
1140 let task_name = format!("Remove match `{}`", *rule);
1141 let remove_match =
1142 async move { conn.remove_match(rule).await }.instrument(trace_span!("{}", task_name));
1143 self.inner.executor.spawn(remove_match, &task_name).detach()
1144 }
1145
1146 pub(crate) async fn hello_bus(&self) -> Result<()> {
1147 let dbus_proxy = fdo::DBusProxy::builder(self)
1148 .cache_properties(CacheProperties::No)
1149 .build()
1150 .await?;
1151 let name = dbus_proxy.hello().await?;
1152
1153 self.inner
1154 .unique_name
1155 .set(name)
1156 .expect("Attempted to set unique_name twice");
1158
1159 Ok(())
1160 }
1161
1162 pub(crate) async fn new(
1163 auth: Authenticated<Box<dyn Socket>>,
1164 bus_connection: bool,
1165 executor: Executor<'static>,
1166 ) -> Result<Self> {
1167 #[cfg(unix)]
1168 let cap_unix_fd = auth.cap_unix_fd;
1169
1170 macro_rules! create_msg_broadcast_channel {
1171 ($size:expr) => {{
1172 let (msg_sender, msg_receiver) = broadcast($size);
1173 let mut msg_receiver = msg_receiver.deactivate();
1174 msg_receiver.set_await_active(false);
1175
1176 (msg_sender, msg_receiver)
1177 }};
1178 }
1179 let (msg_sender, msg_receiver) = create_msg_broadcast_channel!(DEFAULT_MAX_QUEUED);
1181 let mut msg_senders = HashMap::new();
1182 msg_senders.insert(None, msg_sender);
1183
1184 let (method_return_sender, method_return_receiver) =
1186 create_msg_broadcast_channel!(DEFAULT_MAX_METHOD_RETURN_QUEUED);
1187 let rule = MatchRule::builder()
1188 .msg_type(MessageType::MethodReturn)
1189 .build()
1190 .into();
1191 msg_senders.insert(Some(rule), method_return_sender.clone());
1192 let rule = MatchRule::builder()
1193 .msg_type(MessageType::Error)
1194 .build()
1195 .into();
1196 msg_senders.insert(Some(rule), method_return_sender);
1197 let msg_senders = Arc::new(Mutex::new(msg_senders));
1198 let subscriptions = Mutex::new(HashMap::new());
1199
1200 let raw_conn = Arc::new(sync::Mutex::new(auth.conn));
1201
1202 let connection = Self {
1203 inner: Arc::new(ConnectionInner {
1204 raw_conn,
1205 server_guid: auth.server_guid,
1206 #[cfg(unix)]
1207 cap_unix_fd,
1208 bus_conn: bus_connection,
1209 serial: AtomicU32::new(1),
1210 unique_name: OnceCell::new(),
1211 subscriptions,
1212 object_server: OnceCell::new(),
1213 object_server_dispatch_task: OnceCell::new(),
1214 executor,
1215 socket_reader_task: OnceCell::new(),
1216 msg_senders,
1217 msg_receiver,
1218 method_return_receiver,
1219 registered_names: Mutex::new(HashMap::new()),
1220 }),
1221 };
1222
1223 Ok(connection)
1224 }
1225
1226 fn next_serial(&self) -> u32 {
1227 self.inner.serial.fetch_add(1, SeqCst)
1228 }
1229
1230 pub async fn session() -> Result<Self> {
1232 ConnectionBuilder::session()?.build().await
1233 }
1234
1235 pub async fn system() -> Result<Self> {
1237 ConnectionBuilder::system()?.build().await
1238 }
1239
1240 pub fn monitor_activity(&self) -> EventListener {
1244 self.inner
1245 .raw_conn
1246 .lock()
1247 .expect("poisoned lock")
1248 .monitor_activity()
1249 }
1250
1251 #[deprecated(
1253 since = "3.13.0",
1254 note = "Use `peer_credentials` instead, which returns `ConnectionCredentials` which includes
1255 the peer PID."
1256 )]
1257 pub fn peer_pid(&self) -> io::Result<Option<u32>> {
1258 self.inner
1259 .raw_conn
1260 .lock()
1261 .expect("poisoned lock")
1262 .socket()
1263 .peer_pid()
1264 }
1265
1266 #[allow(deprecated)]
1275 pub async fn peer_credentials(&self) -> io::Result<ConnectionCredentials> {
1276 let raw_conn = self.inner.raw_conn.lock().expect("poisoned lock");
1277 let socket = raw_conn.socket();
1278
1279 Ok(ConnectionCredentials {
1280 process_id: socket.peer_pid()?,
1281 #[cfg(unix)]
1282 unix_user_id: socket.uid()?,
1283 #[cfg(not(unix))]
1284 unix_user_id: None,
1285 unix_group_ids: None,
1287 #[cfg(windows)]
1288 windows_sid: socket.peer_sid(),
1289 #[cfg(not(windows))]
1290 windows_sid: None,
1291 linux_security_label: None,
1293 })
1294 }
1295
1296 pub(crate) fn init_socket_reader(&self) {
1297 let inner = &self.inner;
1298 inner
1299 .socket_reader_task
1300 .set(
1301 SocketReader::new(inner.raw_conn.clone(), inner.msg_senders.clone())
1302 .spawn(&inner.executor),
1303 )
1304 .expect("Attempted to set `socket_reader_task` twice");
1305 }
1306}
1307
1308impl<T> Sink<T> for Connection
1309where
1310 T: Into<Arc<Message>>,
1311{
1312 type Error = Error;
1313
1314 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
1315 <&Connection as Sink<Arc<Message>>>::poll_ready(Pin::new(&mut &*self), cx)
1316 }
1317
1318 fn start_send(self: Pin<&mut Self>, msg: T) -> Result<()> {
1319 Pin::new(&mut &*self).start_send(msg)
1320 }
1321
1322 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
1323 <&Connection as Sink<Arc<Message>>>::poll_flush(Pin::new(&mut &*self), cx)
1324 }
1325
1326 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
1327 <&Connection as Sink<Arc<Message>>>::poll_close(Pin::new(&mut &*self), cx)
1328 }
1329}
1330
1331impl<'a, T> Sink<T> for &'a Connection
1332where
1333 T: Into<Arc<Message>>,
1334{
1335 type Error = Error;
1336
1337 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
1338 Poll::Ready(Ok(()))
1340 }
1341
1342 fn start_send(self: Pin<&mut Self>, msg: T) -> Result<()> {
1343 let msg = msg.into();
1344
1345 #[cfg(unix)]
1346 if !msg.fds().is_empty() && !self.inner.cap_unix_fd {
1347 return Err(Error::Unsupported);
1348 }
1349
1350 self.inner
1351 .raw_conn
1352 .lock()
1353 .expect("poisoned lock")
1354 .enqueue_message(msg);
1355
1356 Ok(())
1357 }
1358
1359 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
1360 self.inner.raw_conn.lock().expect("poisoned lock").flush(cx)
1361 }
1362
1363 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
1364 let mut raw_conn = self.inner.raw_conn.lock().expect("poisoned lock");
1365 let res = raw_conn.flush(cx);
1366 match ready!(res) {
1367 Ok(_) => (),
1368 Err(e) => return Poll::Ready(Err(e)),
1369 }
1370
1371 Poll::Ready(raw_conn.close())
1372 }
1373}
1374
1375impl From<crate::blocking::Connection> for Connection {
1376 fn from(conn: crate::blocking::Connection) -> Self {
1377 conn.into_inner()
1378 }
1379}
1380
1381#[derive(Debug)]
1383pub(crate) struct WeakConnection {
1384 inner: Weak<ConnectionInner>,
1385}
1386
1387impl WeakConnection {
1388 pub fn upgrade(&self) -> Option<Connection> {
1390 self.inner.upgrade().map(|inner| Connection { inner })
1391 }
1392}
1393
1394impl From<&Connection> for WeakConnection {
1395 fn from(conn: &Connection) -> Self {
1396 Self {
1397 inner: Arc::downgrade(&conn.inner),
1398 }
1399 }
1400}
1401
1402#[derive(Debug)]
1403enum NameStatus {
1404 Owner(#[allow(unused)] Option<Task<()>>),
1406 Queued(#[allow(unused)] Task<()>),
1408}
1409
1410#[cfg(test)]
1411mod tests {
1412 use futures_util::stream::TryStreamExt;
1413 use ntest::timeout;
1414 use test_log::test;
1415
1416 use crate::{fdo::DBusProxy, AuthMechanism};
1417
1418 use super::*;
1419
1420 async fn test_p2p(
1424 server1: Connection,
1425 client1: Connection,
1426 server2: Connection,
1427 client2: Connection,
1428 ) -> Result<()> {
1429 let forward1 = MessageStream::from(server1.clone()).forward(client2.clone());
1430 let forward2 = MessageStream::from(&client2).forward(server1);
1431 let _forward_task = client1.executor().spawn(
1432 async move { futures_util::try_join!(forward1, forward2) },
1433 "forward_task",
1434 );
1435
1436 let server_ready = Event::new();
1437 let server_ready_listener = server_ready.listen();
1438 let client_done = Event::new();
1439 let client_done_listener = client_done.listen();
1440
1441 let server_future = async move {
1442 let mut stream = MessageStream::from(&server2);
1443 server_ready.notify(1);
1444 let method = loop {
1445 let m = stream.try_next().await?.unwrap();
1446 if m.to_string() == "Method call Test" {
1447 break m;
1448 }
1449 };
1450
1451 server2
1453 .emit_signal(None::<()>, "/", "org.zbus.p2p", "ASignalForYou", &())
1454 .await?;
1455 server2.reply(&method, &("yay")).await?;
1456 client_done_listener.await;
1457
1458 Ok(())
1459 };
1460
1461 let client_future = async move {
1462 let mut stream = MessageStream::from(&client1);
1463 server_ready_listener.await;
1464 let reply = client1
1465 .call_method(None::<()>, "/", Some("org.zbus.p2p"), "Test", &())
1466 .await?;
1467 assert_eq!(reply.to_string(), "Method return");
1468 let m = stream.try_next().await?.unwrap();
1470 client_done.notify(1);
1471 assert_eq!(m.to_string(), "Signal ASignalForYou");
1472 reply.body::<String>()
1473 };
1474
1475 let (val, _) = futures_util::try_join!(client_future, server_future,)?;
1476 assert_eq!(val, "yay");
1477
1478 Ok(())
1479 }
1480
1481 #[test]
1482 #[timeout(15000)]
1483 fn tcp_p2p() {
1484 crate::utils::block_on(test_tcp_p2p()).unwrap();
1485 }
1486
1487 async fn test_tcp_p2p() -> Result<()> {
1488 let (server1, client1) = tcp_p2p_pipe().await?;
1489 let (server2, client2) = tcp_p2p_pipe().await?;
1490
1491 test_p2p(server1, client1, server2, client2).await
1492 }
1493
1494 async fn tcp_p2p_pipe() -> Result<(Connection, Connection)> {
1495 let guid = Guid::generate();
1496
1497 #[cfg(not(feature = "tokio"))]
1498 let (server_conn_builder, client_conn_builder) = {
1499 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
1500 let addr = listener.local_addr().unwrap();
1501 let p1 = std::net::TcpStream::connect(addr).unwrap();
1502 let p0 = listener.incoming().next().unwrap().unwrap();
1503
1504 (
1505 ConnectionBuilder::tcp_stream(p0)
1506 .server(&guid)
1507 .p2p()
1508 .auth_mechanisms(&[AuthMechanism::Anonymous]),
1509 ConnectionBuilder::tcp_stream(p1).p2p(),
1510 )
1511 };
1512
1513 #[cfg(feature = "tokio")]
1514 let (server_conn_builder, client_conn_builder) = {
1515 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1516 let addr = listener.local_addr().unwrap();
1517 let p1 = tokio::net::TcpStream::connect(addr).await.unwrap();
1518 let p0 = listener.accept().await.unwrap().0;
1519
1520 (
1521 ConnectionBuilder::tcp_stream(p0)
1522 .server(&guid)
1523 .p2p()
1524 .auth_mechanisms(&[AuthMechanism::Anonymous]),
1525 ConnectionBuilder::tcp_stream(p1).p2p(),
1526 )
1527 };
1528
1529 futures_util::try_join!(server_conn_builder.build(), client_conn_builder.build())
1530 }
1531
1532 #[cfg(unix)]
1533 #[test]
1534 #[timeout(15000)]
1535 fn unix_p2p() {
1536 crate::utils::block_on(test_unix_p2p()).unwrap();
1537 }
1538
1539 #[cfg(unix)]
1540 async fn test_unix_p2p() -> Result<()> {
1541 let (server1, client1) = unix_p2p_pipe().await?;
1542 let (server2, client2) = unix_p2p_pipe().await?;
1543
1544 test_p2p(server1, client1, server2, client2).await
1545 }
1546
1547 #[cfg(unix)]
1548 async fn unix_p2p_pipe() -> Result<(Connection, Connection)> {
1549 #[cfg(not(feature = "tokio"))]
1550 use std::os::unix::net::UnixStream;
1551 #[cfg(feature = "tokio")]
1552 use tokio::net::UnixStream;
1553 #[cfg(all(windows, not(feature = "tokio")))]
1554 use uds_windows::UnixStream;
1555
1556 let guid = Guid::generate();
1557
1558 let (p0, p1) = UnixStream::pair().unwrap();
1559
1560 futures_util::try_join!(
1561 ConnectionBuilder::unix_stream(p1).p2p().build(),
1562 ConnectionBuilder::unix_stream(p0)
1563 .server(&guid)
1564 .p2p()
1565 .build(),
1566 )
1567 }
1568
1569 #[cfg(any(
1571 all(feature = "vsock", not(feature = "tokio")),
1572 feature = "tokio-vsock"
1573 ))]
1574 #[test]
1575 #[timeout(15000)]
1576 #[ignore]
1577 fn vsock_p2p() {
1578 crate::utils::block_on(test_vsock_p2p()).unwrap();
1579 }
1580
/// Builds two vsock connection pairs, one after the other, and runs the shared p2p scenario on them.
#[cfg(any(
    all(feature = "vsock", not(feature = "tokio")),
    feature = "tokio-vsock"
))]
async fn test_vsock_p2p() -> Result<()> {
    let (first_server, first_client) = vsock_p2p_pipe().await?;
    let (second_server, second_client) = vsock_p2p_pipe().await?;

    test_p2p(first_server, first_client, second_server, second_client).await
}
1591
/// Creates a connected `(server, client)` pair of peer-to-peer connections over vsock using the
/// `vsock` crate (non-tokio builds).
///
/// The server side is restricted to `AuthMechanism::Anonymous`.
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
    let guid = Guid::generate();

    // Listen on port 42 for any CID, then connect to whatever address the listener reports and
    // accept that single incoming stream.
    let listener = vsock::VsockListener::bind_with_cid_port(vsock::VMADDR_CID_ANY, 42).unwrap();
    let listen_addr = listener.local_addr().unwrap();
    let client_stream = vsock::VsockStream::connect(&listen_addr).unwrap();
    let server_stream = listener.incoming().next().unwrap().unwrap();

    let server_builder = ConnectionBuilder::vsock_stream(server_stream)
        .server(&guid)
        .p2p()
        .auth_mechanisms(&[AuthMechanism::Anonymous]);
    let client_builder = ConnectionBuilder::vsock_stream(client_stream).p2p();

    futures_util::try_join!(server_builder.build(), client_builder.build())
}
1610
/// Creates a connected `(server, client)` pair of peer-to-peer connections over vsock using the
/// `tokio-vsock` crate.
///
/// The server side is restricted to `AuthMechanism::Anonymous`.
#[cfg(feature = "tokio-vsock")]
async fn vsock_p2p_pipe() -> Result<(Connection, Connection)> {
    let guid = Guid::generate();

    // NOTE(review): the listener binds CID 2 while the client connects to CID 3 (both port 42).
    // These CIDs look environment-specific (host/guest vsock setup); confirm before relying on them.
    let listener = tokio_vsock::VsockListener::bind(2, 42).unwrap();
    let client_stream = tokio_vsock::VsockStream::connect(3, 42).await.unwrap();
    let server_stream = listener.incoming().next().await.unwrap().unwrap();

    let server_builder = ConnectionBuilder::vsock_stream(server_stream)
        .server(&guid)
        .p2p()
        .auth_mechanisms(&[AuthMechanism::Anonymous]);
    let client_builder = ConnectionBuilder::vsock_stream(client_stream).p2p();

    futures_util::try_join!(server_builder.build(), client_builder.build())
}
1628
/// Runs the serial-ordering check to completion via `crate::utils::block_on`.
#[test]
#[timeout(15000)]
fn serial_monotonically_increases() {
    let check = test_serial_monotonically_increases();
    crate::utils::block_on(check);
}
1634
1635 async fn test_serial_monotonically_increases() {
1636 let c = Connection::session().await.unwrap();
1637 let serial = c.next_serial() + 1;
1638
1639 for next in serial..serial + 10 {
1640 assert_eq!(next, c.next_serial());
1641 }
1642 }
1643
/// Resolves the Windows autolaunch (GDBus) session bus address and connects to it.
#[cfg(all(windows, feature = "windows-gdbus"))]
#[test]
fn connect_gdbus_session_bus() {
    let address = crate::win32::windows_autolaunch_bus_address()
        .expect("Unable to get GDBus session bus address");

    let connected = crate::block_on(async { address.connect().await });
    connected.expect("Unable to connect to session bus");
}
1652
/// Resolves the launchd-provided session bus address on macOS and connects to it.
#[cfg(target_os = "macos")]
#[test]
fn connect_launchd_session_bus() {
    let connected = crate::block_on(async {
        let address =
            crate::address::macos_launchd_bus_address("DBUS_LAUNCHD_SESSION_BUS_SOCKET")
                .await
                .expect("Unable to get Launchd session bus address");
        address.connect().await
    });
    connected.expect("Unable to connect to session bus");
}
1664
/// Runs the drop-releases-name scenario to completion via `crate::utils::block_on`.
#[test]
#[timeout(15000)]
fn disconnect_on_drop() {
    let scenario = test_disconnect_on_drop();
    crate::utils::block_on(scenario);
}
1672
1673 async fn test_disconnect_on_drop() {
1674 #[derive(Default)]
1675 struct MyInterface {}
1676
1677 #[crate::dbus_interface(name = "dev.peelz.FooBar.Baz")]
1678 impl MyInterface {
1679 fn do_thing(&self) {}
1680 }
1681 let name = "dev.peelz.foobar";
1682 let connection = ConnectionBuilder::session()
1683 .unwrap()
1684 .name(name)
1685 .unwrap()
1686 .serve_at("/dev/peelz/FooBar", MyInterface::default())
1687 .unwrap()
1688 .build()
1689 .await
1690 .unwrap();
1691
1692 let connection2 = Connection::session().await.unwrap();
1693 let dbus = DBusProxy::new(&connection2).await.unwrap();
1694 let mut stream = dbus
1695 .receive_name_owner_changed_with_args(&[(0, name), (2, "")])
1696 .await
1697 .unwrap();
1698
1699 drop(connection);
1700
1701 stream.next().await.unwrap();
1703
1704 let name_has_owner = dbus.name_has_owner(name.try_into().unwrap()).await.unwrap();
1706 assert!(!name_has_owner);
1707 }
1708
/// Runs the cookie-authentication p2p scenario twice against a throwaway keyring entry: once with
/// an explicit cookie ID and once without one.
///
/// A single cookie line is written to `<home>/.dbus-keyrings/zbus-test-cookie-context` for the
/// duration of the test. It is removed afterwards, including when one of the runs panics.
#[cfg(any(unix, not(feature = "tokio")))]
#[test]
#[timeout(15000)]
fn unix_p2p_cookie_auth() {
    use crate::utils::block_on;
    use std::{
        fs::{create_dir_all, write},
        time::{SystemTime as Time, UNIX_EPOCH},
    };
    #[cfg(unix)]
    use std::{
        fs::{set_permissions, Permissions},
        os::unix::fs::PermissionsExt,
    };
    use xdg_home::home_dir;

    // Deletes the cookie file when dropped, so the user's keyring directory is cleaned up even if
    // an authentication run panics. A failed removal only fails the test when we are not already
    // unwinding, because panicking during a panic would abort the whole test process.
    struct CookieFileGuard(std::path::PathBuf);

    impl Drop for CookieFileGuard {
        fn drop(&mut self) {
            let removal = std::fs::remove_file(&self.0);
            if !std::thread::panicking() {
                removal.unwrap();
            }
        }
    }

    let cookie_context = "zbus-test-cookie-context";
    let cookie_id = 123456789;
    let cookie = hex::encode(b"our cookie");

    // Keyring directories accessible to other users are expected to be rejected, so keep it 0700.
    let cookie_dir = home_dir().unwrap().join(".dbus-keyrings");
    create_dir_all(&cookie_dir).unwrap();
    #[cfg(unix)]
    set_permissions(&cookie_dir, Permissions::from_mode(0o700)).unwrap();

    // Keyring line format: `<cookie id> <creation time in Unix seconds> <hex-encoded cookie>`.
    let cookie_file = CookieFileGuard(cookie_dir.join(cookie_context));
    let ts = Time::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
    let cookie_entry = format!("{cookie_id} {ts} {cookie}");
    write(&cookie_file.0, cookie_entry).unwrap();

    let res1 = block_on(test_unix_p2p_cookie_auth(cookie_context, Some(cookie_id)));
    let res2 = block_on(test_unix_p2p_cookie_auth(cookie_context, None));

    // Clean up before asserting, so a failed run still leaves no stale cookie behind.
    drop(cookie_file);

    res1.unwrap();
    res2.unwrap();
}
1752
1753 #[cfg(any(unix, not(feature = "tokio")))]
1754 async fn test_unix_p2p_cookie_auth(
1755 cookie_context: &'static str,
1756 cookie_id: Option<usize>,
1757 ) -> Result<()> {
1758 #[cfg(all(unix, not(feature = "tokio")))]
1759 use std::os::unix::net::UnixStream;
1760 #[cfg(all(unix, feature = "tokio"))]
1761 use tokio::net::UnixStream;
1762 #[cfg(all(windows, not(feature = "tokio")))]
1763 use uds_windows::UnixStream;
1764
1765 let guid = Guid::generate();
1766
1767 let (p0, p1) = UnixStream::pair().unwrap();
1768 let mut server_builder = ConnectionBuilder::unix_stream(p0)
1769 .server(&guid)
1770 .p2p()
1771 .auth_mechanisms(&[AuthMechanism::Cookie])
1772 .cookie_context(cookie_context)
1773 .unwrap();
1774 if let Some(cookie_id) = cookie_id {
1775 server_builder = server_builder.cookie_id(cookie_id);
1776 }
1777
1778 futures_util::try_join!(
1779 ConnectionBuilder::unix_stream(p1).p2p().build(),
1780 server_builder.build(),
1781 )
1782 .map(|_| ())
1783 }
1784}