zbus/
connection_builder.rs

1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3use event_listener::Event;
4use static_assertions::assert_impl_all;
5#[cfg(not(feature = "tokio"))]
6use std::net::TcpStream;
7#[cfg(all(unix, not(feature = "tokio")))]
8use std::os::unix::net::UnixStream;
9use std::{
10    collections::{HashMap, HashSet, VecDeque},
11    convert::TryInto,
12    sync::Arc,
13};
14#[cfg(feature = "tokio")]
15use tokio::net::TcpStream;
16#[cfg(all(unix, feature = "tokio"))]
17use tokio::net::UnixStream;
18#[cfg(feature = "tokio-vsock")]
19use tokio_vsock::VsockStream;
20#[cfg(windows)]
21use uds_windows::UnixStream;
22#[cfg(all(feature = "vsock", not(feature = "tokio")))]
23use vsock::VsockStream;
24
25use zvariant::{ObjectPath, Str};
26
27use crate::{
28    address::{self, Address},
29    async_lock::RwLock,
30    handshake,
31    names::{InterfaceName, UniqueName, WellKnownName},
32    raw::Socket,
33    AuthMechanism, Authenticated, Connection, Error, Executor, Guid, Interface, Result,
34};
35
/// Capacity of the main (unfiltered) message queue, used when
/// [`ConnectionBuilder::max_queued`] is never called.
const DEFAULT_MAX_QUEUED: usize = 64;
37
38#[derive(Debug)]
39enum Target {
40    UnixStream(UnixStream),
41    TcpStream(TcpStream),
42    #[cfg(any(
43        all(feature = "vsock", not(feature = "tokio")),
44        feature = "tokio-vsock"
45    ))]
46    VsockStream(VsockStream),
47    Address(Address),
48    Socket(Box<dyn Socket>),
49}
50
51type Interfaces<'a> =
52    HashMap<ObjectPath<'a>, HashMap<InterfaceName<'static>, Arc<RwLock<dyn Interface>>>>;
53
54/// A builder for [`zbus::Connection`].
55#[derive(derivative::Derivative)]
56#[derivative(Debug)]
57#[must_use]
58pub struct ConnectionBuilder<'a> {
59    target: Target,
60    max_queued: Option<usize>,
61    guid: Option<&'a Guid>,
62    p2p: bool,
63    internal_executor: bool,
64    #[derivative(Debug = "ignore")]
65    interfaces: Interfaces<'a>,
66    names: HashSet<WellKnownName<'a>>,
67    auth_mechanisms: Option<VecDeque<AuthMechanism>>,
68    unique_name: Option<UniqueName<'a>>,
69    cookie_context: Option<handshake::CookieContext<'a>>,
70    cookie_id: Option<usize>,
71}
72
73assert_impl_all!(ConnectionBuilder<'_>: Send, Sync, Unpin);
74
75impl<'a> ConnectionBuilder<'a> {
76    /// Create a builder for the session/user message bus connection.
77    pub fn session() -> Result<Self> {
78        Ok(Self::new(Target::Address(Address::session()?)))
79    }
80
81    /// Create a builder for the system-wide message bus connection.
82    pub fn system() -> Result<Self> {
83        Ok(Self::new(Target::Address(Address::system()?)))
84    }
85
86    /// Create a builder for connection that will use the given [D-Bus bus address].
87    ///
88    /// # Example
89    ///
90    /// Here is an example of connecting to an IBus service:
91    ///
92    /// ```no_run
93    /// # use std::error::Error;
94    /// # use zbus::ConnectionBuilder;
95    /// # use zbus::block_on;
96    /// #
97    /// # block_on(async {
98    /// let addr = "unix:\
99    ///     path=/home/zeenix/.cache/ibus/dbus-ET0Xzrk9,\
100    ///     guid=fdd08e811a6c7ebe1fef0d9e647230da";
101    /// let conn = ConnectionBuilder::address(addr)?
102    ///     .build()
103    ///     .await?;
104    ///
105    /// // Do something useful with `conn`..
106    /// #     drop(conn);
107    /// #     Ok::<(), zbus::Error>(())
108    /// # }).unwrap();
109    /// #
110    /// # Ok::<_, Box<dyn Error + Send + Sync>>(())
111    /// ```
112    ///
113    /// **Note:** The IBus address is different for each session. You can find the address for your
114    /// current session using `ibus address` command.
115    ///
116    /// [D-Bus bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses
117    pub fn address<A>(address: A) -> Result<Self>
118    where
119        A: TryInto<Address>,
120        A::Error: Into<Error>,
121    {
122        Ok(Self::new(Target::Address(
123            address.try_into().map_err(Into::into)?,
124        )))
125    }
126
127    /// Create a builder for connection that will use the given unix stream.
128    ///
129    /// If the default `async-io` feature is disabled, this method will expect
130    /// [`tokio::net::UnixStream`](https://docs.rs/tokio/latest/tokio/net/struct.UnixStream.html)
131    /// argument.
132    pub fn unix_stream(stream: UnixStream) -> Self {
133        Self::new(Target::UnixStream(stream))
134    }
135
136    /// Create a builder for connection that will use the given TCP stream.
137    ///
138    /// If the default `async-io` feature is disabled, this method will expect
139    /// [`tokio::net::TcpStream`](https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html)
140    /// argument.
141    pub fn tcp_stream(stream: TcpStream) -> Self {
142        Self::new(Target::TcpStream(stream))
143    }
144
145    /// Create a builder for connection that will use the given VSOCK stream.
146    ///
147    /// This method is only available when either `vsock` or `tokio-vsock` feature is enabled. The
148    /// type of `stream` is `vsock::VsockStream` with `vsock` feature and `tokio_vsock::VsockStream`
149    /// with `tokio-vsock` feature.
150    #[cfg(any(
151        all(feature = "vsock", not(feature = "tokio")),
152        feature = "tokio-vsock"
153    ))]
154    pub fn vsock_stream(stream: VsockStream) -> Self {
155        Self::new(Target::VsockStream(stream))
156    }
157
158    /// Create a builder for connection that will use the given socket.
159    pub fn socket<S: Socket + 'static>(socket: S) -> Self {
160        Self::new(Target::Socket(Box::new(socket)))
161    }
162
163    /// Specify the mechanisms to use during authentication.
164    pub fn auth_mechanisms(mut self, auth_mechanisms: &[AuthMechanism]) -> Self {
165        self.auth_mechanisms = Some(VecDeque::from(auth_mechanisms.to_vec()));
166
167        self
168    }
169
170    /// The cookie context to use during authentication.
171    ///
172    /// This is only used when the `cookie` authentication mechanism is enabled and only valid for
173    /// server connection.
174    ///
175    /// If not specified, the default cookie context of `org_freedesktop_general` will be used.
176    ///
177    /// # Errors
178    ///
179    /// If the given string is not a valid cookie context.
180    pub fn cookie_context<C>(mut self, context: C) -> Result<Self>
181    where
182        C: Into<Str<'a>>,
183    {
184        self.cookie_context = Some(context.into().try_into()?);
185
186        Ok(self)
187    }
188
189    /// The ID of the cookie to use during authentication.
190    ///
191    /// This is only used when the `cookie` authentication mechanism is enabled and only valid for
192    /// server connection.
193    ///
194    /// If not specified, the first cookie found in the cookie context file will be used.
195    pub fn cookie_id(mut self, id: usize) -> Self {
196        self.cookie_id = Some(id);
197
198        self
199    }
200
201    /// The to-be-created connection will be a peer-to-peer connection.
202    pub fn p2p(mut self) -> Self {
203        self.p2p = true;
204
205        self
206    }
207
208    /// The to-be-created connection will be a server using the given GUID.
209    ///
210    /// The to-be-created connection will wait for incoming client authentication handshake and
211    /// negotiation messages, for peer-to-peer communications after successful creation.
212    pub fn server(mut self, guid: &'a Guid) -> Self {
213        self.guid = Some(guid);
214
215        self
216    }
217
218    /// Set the capacity of the main (unfiltered) queue.
219    ///
220    /// Since typically you'd want to set this at instantiation time, you can set it through the
221    /// builder.
222    ///
223    /// # Example
224    ///
225    /// ```
226    /// # use std::error::Error;
227    /// # use zbus::ConnectionBuilder;
228    /// # use zbus::block_on;
229    /// #
230    /// # block_on(async {
231    /// let conn = ConnectionBuilder::session()?
232    ///     .max_queued(30)
233    ///     .build()
234    ///     .await?;
235    /// assert_eq!(conn.max_queued(), 30);
236    ///
237    /// #     Ok::<(), zbus::Error>(())
238    /// # }).unwrap();
239    /// #
240    /// // Do something useful with `conn`..
241    /// # Ok::<_, Box<dyn Error + Send + Sync>>(())
242    /// ```
243    pub fn max_queued(mut self, max: usize) -> Self {
244        self.max_queued = Some(max);
245
246        self
247    }
248
249    /// Enable or disable the internal executor thread.
250    ///
251    /// The thread is enabled by default.
252    ///
253    /// See [Connection::executor] for more details.
254    pub fn internal_executor(mut self, enabled: bool) -> Self {
255        self.internal_executor = enabled;
256
257        self
258    }
259
260    /// Register a D-Bus [`Interface`] to be served at a given path.
261    ///
262    /// This is similar to [`zbus::ObjectServer::at`], except that it allows you to have your
263    /// interfaces available immediately after the connection is established. Typically, this is
264    /// exactly what you'd want. Also in contrast to [`zbus::ObjectServer::at`], this method will
265    /// replace any previously added interface with the same name at the same path.
266    pub fn serve_at<P, I>(mut self, path: P, iface: I) -> Result<Self>
267    where
268        I: Interface,
269        P: TryInto<ObjectPath<'a>>,
270        P::Error: Into<Error>,
271    {
272        let path = path.try_into().map_err(Into::into)?;
273        let entry = self.interfaces.entry(path).or_default();
274        entry.insert(I::name(), Arc::new(RwLock::new(iface)));
275
276        Ok(self)
277    }
278
279    /// Register a well-known name for this connection on the bus.
280    ///
281    /// This is similar to [`zbus::Connection::request_name`], except the name is requested as part
282    /// of the connection setup ([`ConnectionBuilder::build`]), immediately after interfaces
283    /// registered (through [`ConnectionBuilder::serve_at`]) are advertised. Typically this is
284    /// exactly what you want.
285    pub fn name<W>(mut self, well_known_name: W) -> Result<Self>
286    where
287        W: TryInto<WellKnownName<'a>>,
288        W::Error: Into<Error>,
289    {
290        let well_known_name = well_known_name.try_into().map_err(Into::into)?;
291        self.names.insert(well_known_name);
292
293        Ok(self)
294    }
295
296    /// Sets the unique name of the connection.
297    ///
298    /// # Panics
299    ///
300    /// This method panics if the to-be-created connection is not a peer-to-peer connection.
301    /// It will always panic if the connection is to a message bus as it's the bus that assigns
302    /// peers their unique names. This is mainly provided for bus implementations. All other users
303    /// should not need to use this method.
304    pub fn unique_name<U>(mut self, unique_name: U) -> Result<Self>
305    where
306        U: TryInto<UniqueName<'a>>,
307        U::Error: Into<Error>,
308    {
309        if !self.p2p {
310            panic!("unique name can only be set for peer-to-peer connections");
311        }
312        let name = unique_name.try_into().map_err(Into::into)?;
313        self.unique_name = Some(name);
314
315        Ok(self)
316    }
317
318    /// Build the connection, consuming the builder.
319    ///
320    /// # Errors
321    ///
322    /// Until server-side bus connection is supported, attempting to build such a connection will
323    /// result in [`Error::Unsupported`] error.
324    pub async fn build(self) -> Result<Connection> {
325        let executor = Executor::new();
326        #[cfg(not(feature = "tokio"))]
327        let internal_executor = self.internal_executor;
328        // Box the future as it's large and can cause stack overflow.
329        let conn = Box::pin(executor.run(self.build_(executor.clone()))).await?;
330
331        #[cfg(not(feature = "tokio"))]
332        start_internal_executor(&executor, internal_executor)?;
333
334        Ok(conn)
335    }
336
337    async fn build_(self, executor: Executor<'static>) -> Result<Connection> {
338        let stream = match self.target {
339            #[cfg(not(feature = "tokio"))]
340            Target::UnixStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
341            #[cfg(all(unix, feature = "tokio"))]
342            Target::UnixStream(stream) => Box::new(stream) as Box<dyn Socket>,
343            #[cfg(all(not(unix), feature = "tokio"))]
344            Target::UnixStream(_) => return Err(Error::Unsupported),
345            #[cfg(not(feature = "tokio"))]
346            Target::TcpStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
347            #[cfg(feature = "tokio")]
348            Target::TcpStream(stream) => Box::new(stream) as Box<dyn Socket>,
349            #[cfg(all(feature = "vsock", not(feature = "tokio")))]
350            Target::VsockStream(stream) => Box::new(Async::new(stream)?) as Box<dyn Socket>,
351            #[cfg(feature = "tokio-vsock")]
352            Target::VsockStream(stream) => Box::new(stream) as Box<dyn Socket>,
353            Target::Address(address) => match address.connect().await? {
354                #[cfg(any(unix, not(feature = "tokio")))]
355                address::Stream::Unix(stream) => Box::new(stream) as Box<dyn Socket>,
356                address::Stream::Tcp(stream) => Box::new(stream) as Box<dyn Socket>,
357                #[cfg(any(
358                    all(feature = "vsock", not(feature = "tokio")),
359                    feature = "tokio-vsock"
360                ))]
361                address::Stream::Vsock(stream) => Box::new(stream) as Box<dyn Socket>,
362            },
363            Target::Socket(stream) => stream,
364        };
365        let auth = match self.guid {
366            None => {
367                // SASL Handshake
368                Authenticated::client(stream, self.auth_mechanisms).await?
369            }
370            Some(guid) => {
371                if !self.p2p {
372                    return Err(Error::Unsupported);
373                }
374
375                #[cfg(unix)]
376                let client_uid = stream.uid()?;
377
378                #[cfg(windows)]
379                let client_sid = stream.peer_sid();
380
381                Authenticated::server(
382                    stream,
383                    guid.clone(),
384                    #[cfg(unix)]
385                    client_uid,
386                    #[cfg(windows)]
387                    client_sid,
388                    self.auth_mechanisms,
389                    self.cookie_id,
390                    self.cookie_context.unwrap_or_default(),
391                )
392                .await?
393            }
394        };
395
396        let mut conn = Connection::new(auth, !self.p2p, executor).await?;
397        conn.set_max_queued(self.max_queued.unwrap_or(DEFAULT_MAX_QUEUED));
398        if let Some(unique_name) = self.unique_name {
399            conn.set_unique_name(unique_name)?;
400        }
401
402        if !self.interfaces.is_empty() {
403            let object_server = conn.sync_object_server(false, None);
404            for (path, interfaces) in self.interfaces {
405                for (name, iface) in interfaces {
406                    let future = object_server.at_ready(path.to_owned(), name, || iface);
407                    let added = future.await?;
408                    // Duplicates shouldn't happen.
409                    assert!(added);
410                }
411            }
412
413            let started_event = Event::new();
414            let listener = started_event.listen();
415            conn.start_object_server(Some(started_event));
416
417            listener.await;
418        }
419
420        // Start the socket reader task.
421        conn.init_socket_reader();
422
423        if !self.p2p {
424            // Now that the server has approved us, we must send the bus Hello, as per specs
425            conn.hello_bus().await?;
426        }
427
428        for name in self.names {
429            conn.request_name(name).await?;
430        }
431
432        Ok(conn)
433    }
434
435    fn new(target: Target) -> Self {
436        Self {
437            target,
438            p2p: false,
439            max_queued: None,
440            guid: None,
441            internal_executor: true,
442            interfaces: HashMap::new(),
443            names: HashSet::new(),
444            auth_mechanisms: None,
445            unique_name: None,
446            cookie_id: None,
447            cookie_context: None,
448        }
449    }
450}
451
452/// Start the internal executor thread.
453///
454/// Returns a dummy task that keep the executor ticking thread from exiting due to absence of any
455/// tasks until socket reader task kicks in.
456#[cfg(not(feature = "tokio"))]
457fn start_internal_executor(executor: &Executor<'static>, internal_executor: bool) -> Result<()> {
458    if internal_executor {
459        let executor = executor.clone();
460        std::thread::Builder::new()
461            .name("zbus::Connection executor".into())
462            .spawn(move || {
463                crate::utils::block_on(async move {
464                    // Run as long as there is a task to run.
465                    while !executor.is_empty() {
466                        executor.tick().await;
467                    }
468                })
469            })?;
470    }
471
472    Ok(())
473}