zbus/message/
builder.rs

1use std::{
2    borrow::Cow,
3    io::{Cursor, Write},
4    num::NonZeroU32,
5    sync::Arc,
6};
7#[cfg(unix)]
8use zvariant::OwnedFd;
9
10use enumflags2::BitFlags;
11use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
12use zvariant::{serialized, Endian, Signature};
13
14use crate::{
15    message::{EndianSig, Fields, Flags, Header, Message, PrimaryHeader, Sequence, Type},
16    utils::padding_for_8_bytes,
17    zvariant::{serialized::Context, DynamicType, ObjectPath},
18    Error, Result,
19};
20
21use crate::message::header::MAX_MESSAGE_SIZE;
22
23#[cfg(unix)]
24type BuildGenericResult = Vec<OwnedFd>;
25
26#[cfg(not(unix))]
27type BuildGenericResult = ();
28
/// Expands to a D-Bus serialization [`Context`] for a builder.
///
/// The endianness is taken from the builder's primary header. `$offset` is passed unchanged
/// as the second (byte offset) argument of `Context::new_dbus`.
macro_rules! dbus_context {
    ($builder:ident, $offset:expr) => {
        Context::new_dbus(
            $builder.header.primary().endian_sig().into(),
            $offset,
        )
    };
}
34
35/// A builder for a [`Message`].
36#[derive(Debug, Clone)]
37pub struct Builder<'a> {
38    header: Header<'a>,
39}
40
41impl<'a> Builder<'a> {
42    pub(super) fn new(msg_type: Type) -> Self {
43        let primary = PrimaryHeader::new(msg_type, 0);
44        let fields = Fields::new();
45        let header = Header::new(primary, fields);
46        Self { header }
47    }
48
49    /// Add flags to the message.
50    ///
51    /// See [`Flags`] documentation for the meaning of the flags.
52    ///
53    /// The function will return an error if invalid flags are given for the message type.
54    pub fn with_flags(mut self, flag: Flags) -> Result<Self> {
55        if self.header.message_type() != Type::MethodCall
56            && BitFlags::from_flag(flag).contains(Flags::NoReplyExpected)
57        {
58            return Err(Error::InvalidField);
59        }
60        let flags = self.header.primary().flags() | flag;
61        self.header.primary_mut().set_flags(flags);
62        Ok(self)
63    }
64
65    /// Set the unique name of the sending connection.
66    pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
67    where
68        S: TryInto<UniqueName<'s>>,
69        S::Error: Into<Error>,
70    {
71        self.header.fields_mut().sender = Some(sender.try_into().map_err(Into::into)?);
72        Ok(self)
73    }
74
75    /// Set the object to send a call to, or the object a signal is emitted from.
76    pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
77    where
78        P: TryInto<ObjectPath<'p>>,
79        P::Error: Into<Error>,
80    {
81        self.header.fields_mut().path = Some(path.try_into().map_err(Into::into)?);
82        Ok(self)
83    }
84
85    /// Set the interface to invoke a method call on, or that a signal is emitted from.
86    pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
87    where
88        I: TryInto<InterfaceName<'i>>,
89        I::Error: Into<Error>,
90    {
91        self.header.fields_mut().interface = Some(interface.try_into().map_err(Into::into)?);
92        Ok(self)
93    }
94
95    /// Set the member, either the method name or signal name.
96    pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
97    where
98        M: TryInto<MemberName<'m>>,
99        M::Error: Into<Error>,
100    {
101        self.header.fields_mut().member = Some(member.try_into().map_err(Into::into)?);
102        Ok(self)
103    }
104
105    pub(super) fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
106    where
107        E: TryInto<ErrorName<'e>>,
108        E::Error: Into<Error>,
109    {
110        self.header.fields_mut().error_name = Some(error.try_into().map_err(Into::into)?);
111        Ok(self)
112    }
113
114    /// Set the name of the connection this message is intended for.
115    pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
116    where
117        D: TryInto<BusName<'d>>,
118        D::Error: Into<Error>,
119    {
120        self.header.fields_mut().destination = Some(destination.try_into().map_err(Into::into)?);
121        Ok(self)
122    }
123
124    /// Override the generated or inherited serial.  This is a low level modification,
125    /// generally you should not need to use this.
126    pub fn serial(mut self, serial: NonZeroU32) -> Self {
127        self.header.primary_mut().set_serial_num(serial);
128        self
129    }
130
131    /// Override the reply serial. This is a low level modification, generally you should use
132    /// `Message::method_return` instead.
133    pub fn reply_serial(mut self, serial: Option<NonZeroU32>) -> Self {
134        self.header.fields_mut().reply_serial = serial;
135        self
136    }
137
138    pub(super) fn reply_to(mut self, reply_to: &Header<'_>) -> Result<Self> {
139        let serial = reply_to.primary().serial_num();
140        self.header.fields_mut().reply_serial = Some(serial);
141        self = self.endian(reply_to.primary().endian_sig().into());
142
143        if let Some(sender) = reply_to.sender() {
144            self.destination(sender.to_owned())
145        } else {
146            Ok(self)
147        }
148    }
149
150    /// Set the endianness of the message.
151    ///
152    /// The default endianness is native.
153    pub fn endian(mut self, endian: Endian) -> Self {
154        let sig = EndianSig::from(endian);
155        self.header.primary_mut().set_endian_sig(sig);
156
157        self
158    }
159
160    /// Build the [`Message`] with the given body.
161    ///
162    /// You may pass `()` as the body if the message has no body.
163    ///
164    /// The caller is currently required to ensure that the resulting message contains the headers
165    /// as compliant with the [specification]. Additional checks may be added to this builder over
166    /// time as needed.
167    ///
168    /// [specification]:
169    /// https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-header-fields
170    pub fn build<B>(self, body: &B) -> Result<Message>
171    where
172        B: serde::ser::Serialize + DynamicType,
173    {
174        let ctxt = dbus_context!(self, 0);
175
176        // Note: this iterates the body twice, but we prefer efficient handling of large messages
177        // to efficient handling of ones that are complex to serialize.
178        let body_size = zvariant::serialized_size(ctxt, body)?;
179
180        let signature = body.signature();
181
182        self.build_generic(signature, body_size, move |cursor| {
183            // SAFETY: build_generic puts FDs and the body in the same Message.
184            unsafe { zvariant::to_writer(cursor, ctxt, body) }
185                .map(|s| {
186                    #[cfg(unix)]
187                    {
188                        s.into_fds()
189                    }
190                    #[cfg(not(unix))]
191                    {
192                        let _ = s;
193                    }
194                })
195                .map_err(Into::into)
196        })
197    }
198
199    /// Create a new message from a raw slice of bytes to populate the body with, rather than by
200    /// serializing a value. The message body will be the exact bytes.
201    ///
202    /// # Safety
203    ///
204    /// This method is unsafe because it can be used to build an invalid message.
205    pub unsafe fn build_raw_body<S>(
206        self,
207        body_bytes: &[u8],
208        signature: S,
209        #[cfg(unix)] fds: Vec<OwnedFd>,
210    ) -> Result<Message>
211    where
212        S: TryInto<Signature>,
213        S::Error: Into<Error>,
214    {
215        let signature = signature.try_into().map_err(Into::into)?;
216        let body_size = serialized::Size::new(body_bytes.len(), dbus_context!(self, 0));
217        #[cfg(unix)]
218        let body_size = {
219            let num_fds = fds.len().try_into().map_err(|_| Error::ExcessData)?;
220            body_size.set_num_fds(num_fds)
221        };
222
223        self.build_generic(
224            signature,
225            body_size,
226            move |cursor: &mut Cursor<&mut Vec<u8>>| {
227                cursor.write_all(body_bytes)?;
228
229                #[cfg(unix)]
230                return Ok::<Vec<OwnedFd>, Error>(fds);
231
232                #[cfg(not(unix))]
233                return Ok::<(), Error>(());
234            },
235        )
236    }
237
238    fn build_generic<WriteFunc>(
239        self,
240        signature: Signature,
241        body_size: serialized::Size,
242        write_body: WriteFunc,
243    ) -> Result<Message>
244    where
245        WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
246    {
247        let ctxt = dbus_context!(self, 0);
248        let mut header = self.header;
249
250        header.fields_mut().signature = Cow::Owned(signature);
251
252        let body_len_u32 = body_size.size().try_into().map_err(|_| Error::ExcessData)?;
253        header.primary_mut().set_body_len(body_len_u32);
254
255        #[cfg(unix)]
256        {
257            let fds_len = body_size.num_fds();
258            if fds_len != 0 {
259                header.fields_mut().unix_fds = Some(fds_len);
260            }
261        }
262
263        let hdr_len = *zvariant::serialized_size(ctxt, &header)?;
264        // We need to align the body to 8-byte boundary.
265        let body_padding = padding_for_8_bytes(hdr_len);
266        let body_offset = hdr_len + body_padding;
267        let total_len = body_offset + body_size.size();
268        if total_len > MAX_MESSAGE_SIZE {
269            return Err(Error::ExcessData);
270        }
271        let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
272        let mut cursor = Cursor::new(&mut bytes);
273
274        // SAFETY: There are no FDs involved.
275        unsafe { zvariant::to_writer(&mut cursor, ctxt, &header) }?;
276        cursor.write_all(&[0u8; 8][..body_padding])?;
277        #[cfg(unix)]
278        let fds: Vec<_> = write_body(&mut cursor)?.into_iter().collect();
279        #[cfg(not(unix))]
280        write_body(&mut cursor)?;
281
282        let primary_header = header.into_primary();
283        #[cfg(unix)]
284        let bytes = serialized::Data::new_fds(bytes, ctxt, fds);
285        #[cfg(not(unix))]
286        let bytes = serialized::Data::new(bytes, ctxt);
287
288        Ok(Message {
289            inner: Arc::new(super::Inner {
290                primary_header,
291                quick_fields: std::sync::OnceLock::new(),
292                bytes,
293                body_offset,
294                recv_seq: Sequence::default(),
295            }),
296        })
297    }
298}
299
300impl<'m> From<Header<'m>> for Builder<'m> {
301    fn from(mut header: Header<'m>) -> Self {
302        // Signature and Fds are added by body* methods.
303        let fields = header.fields_mut();
304        fields.signature = Cow::Owned(Signature::Unit);
305        fields.unix_fds = None;
306
307        Self { header }
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::Message;
    use crate::Error;
    use test_log::test;
    use zvariant::Endian;

    #[test]
    fn test_raw() -> Result<(), Error> {
        // An `ai` body: the array's byte length (16) followed by the i32 values 1, 2, 3, 4, all
        // encoded little-endian.
        let raw_body: &[u8] = &[16, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0];
        // The builder defaults to native endianness. Pin it to little-endian so the hand-encoded
        // body above is also decoded correctly on big-endian hosts.
        let message_builder = Message::signal("/", "test.test", "test")?.endian(Endian::Little);
        // SAFETY: `raw_body` is a valid little-endian encoding of an `ai` value. That matches
        // both the signature passed here and the builder's endianness. It contains no FD
        // indices, so an empty FD list is correct.
        let message = unsafe {
            message_builder.build_raw_body(
                raw_body,
                "ai",
                #[cfg(unix)]
                vec![],
            )?
        };

        let output: Vec<i32> = message.body().deserialize()?;
        assert_eq!(output, vec![1, 2, 3, 4]);

        Ok(())
    }
}