zbus/
message_builder.rs

1use std::{
2    convert::TryInto,
3    io::{Cursor, Write},
4};
5
6#[cfg(unix)]
7use crate::Fds;
8#[cfg(unix)]
9use std::{
10    os::unix::io::RawFd,
11    sync::{Arc, RwLock},
12};
13
14use enumflags2::BitFlags;
15use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
16
17use crate::{
18    utils::padding_for_8_bytes,
19    zvariant::{DynamicType, EncodingContext, ObjectPath, Signature},
20    Error, Message, MessageField, MessageFieldCode, MessageFields, MessageFlags, MessageHeader,
21    MessagePrimaryHeader, MessageSequence, MessageType, QuickMessageFields, Result,
22    MAX_MESSAGE_SIZE,
23};
24
/// Value returned by the body-writing closure given to `MessageBuilder::build_generic`.
///
/// On unix this is the list of raw file descriptors referenced by the serialized body; they are
/// stored in the built `Message`.
#[cfg(unix)]
type BuildGenericResult = Vec<RawFd>;

/// Value returned by the body-writing closure given to `MessageBuilder::build_generic`.
///
/// Non-unix targets carry no file descriptors, so the closure returns nothing.
#[cfg(not(unix))]
type BuildGenericResult = ();
30
/// Expands to a native-endian D-Bus `EncodingContext`.
///
/// `$offset` is forwarded as the `n_bytes_before` argument of `EncodingContext::new_dbus`
/// (presumably the number of bytes preceding the encoded data, used for alignment — confirm
/// against zvariant).
macro_rules! dbus_context {
    ($offset:expr) => {
        EncodingContext::<byteorder::NativeEndian>::new_dbus($offset)
    };
}
36
37/// A builder for [`Message`]
38#[derive(Debug, Clone)]
39pub struct MessageBuilder<'a> {
40    header: MessageHeader<'a>,
41}
42
43impl<'a> MessageBuilder<'a> {
44    fn new(msg_type: MessageType) -> Self {
45        let primary = MessagePrimaryHeader::new(msg_type, 0);
46        let fields = MessageFields::new();
47        let header = MessageHeader::new(primary, fields);
48        Self { header }
49    }
50
51    /// Create a message of type [`MessageType::MethodCall`].
52    pub fn method_call<'p: 'a, 'm: 'a, P, M>(path: P, method_name: M) -> Result<Self>
53    where
54        P: TryInto<ObjectPath<'p>>,
55        M: TryInto<MemberName<'m>>,
56        P::Error: Into<Error>,
57        M::Error: Into<Error>,
58    {
59        Self::new(MessageType::MethodCall)
60            .path(path)?
61            .member(method_name)
62    }
63
64    /// Create a message of type [`MessageType::Signal`].
65    pub fn signal<'p: 'a, 'i: 'a, 'm: 'a, P, I, M>(path: P, interface: I, name: M) -> Result<Self>
66    where
67        P: TryInto<ObjectPath<'p>>,
68        I: TryInto<InterfaceName<'i>>,
69        M: TryInto<MemberName<'m>>,
70        P::Error: Into<Error>,
71        I::Error: Into<Error>,
72        M::Error: Into<Error>,
73    {
74        Self::new(MessageType::Signal)
75            .path(path)?
76            .interface(interface)?
77            .member(name)
78    }
79
80    /// Create a message of type [`MessageType::MethodReturn`].
81    pub fn method_return(reply_to: &MessageHeader<'_>) -> Result<Self> {
82        Self::new(MessageType::MethodReturn).reply_to(reply_to)
83    }
84
85    /// Create a message of type [`MessageType::Error`].
86    pub fn error<'e: 'a, E>(reply_to: &MessageHeader<'_>, name: E) -> Result<Self>
87    where
88        E: TryInto<ErrorName<'e>>,
89        E::Error: Into<Error>,
90    {
91        Self::new(MessageType::Error)
92            .error_name(name)?
93            .reply_to(reply_to)
94    }
95
96    /// Add flags to the message.
97    ///
98    /// See [`MessageFlags`] documentation for the meaning of the flags.
99    ///
100    /// The function will return an error if invalid flags are given for the message type.
101    pub fn with_flags(mut self, flag: MessageFlags) -> Result<Self> {
102        if self.header.message_type()? != MessageType::MethodCall
103            && BitFlags::from_flag(flag).contains(MessageFlags::NoReplyExpected)
104        {
105            return Err(Error::InvalidField);
106        }
107        let flags = self.header.primary().flags() | flag;
108        self.header.primary_mut().set_flags(flags);
109        Ok(self)
110    }
111
112    /// Set the unique name of the sending connection.
113    pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
114    where
115        S: TryInto<UniqueName<'s>>,
116        S::Error: Into<Error>,
117    {
118        self.header
119            .fields_mut()
120            .replace(MessageField::Sender(sender.try_into().map_err(Into::into)?));
121        Ok(self)
122    }
123
124    /// Set the object to send a call to, or the object a signal is emitted from.
125    pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
126    where
127        P: TryInto<ObjectPath<'p>>,
128        P::Error: Into<Error>,
129    {
130        self.header
131            .fields_mut()
132            .replace(MessageField::Path(path.try_into().map_err(Into::into)?));
133        Ok(self)
134    }
135
136    /// Set the interface to invoke a method call on, or that a signal is emitted from.
137    pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
138    where
139        I: TryInto<InterfaceName<'i>>,
140        I::Error: Into<Error>,
141    {
142        self.header.fields_mut().replace(MessageField::Interface(
143            interface.try_into().map_err(Into::into)?,
144        ));
145        Ok(self)
146    }
147
148    /// Set the member, either the method name or signal name.
149    pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
150    where
151        M: TryInto<MemberName<'m>>,
152        M::Error: Into<Error>,
153    {
154        self.header
155            .fields_mut()
156            .replace(MessageField::Member(member.try_into().map_err(Into::into)?));
157        Ok(self)
158    }
159
160    fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
161    where
162        E: TryInto<ErrorName<'e>>,
163        E::Error: Into<Error>,
164    {
165        self.header.fields_mut().replace(MessageField::ErrorName(
166            error.try_into().map_err(Into::into)?,
167        ));
168        Ok(self)
169    }
170
171    /// Set the name of the connection this message is intended for.
172    pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
173    where
174        D: TryInto<BusName<'d>>,
175        D::Error: Into<Error>,
176    {
177        self.header.fields_mut().replace(MessageField::Destination(
178            destination.try_into().map_err(Into::into)?,
179        ));
180        Ok(self)
181    }
182
183    fn reply_to(mut self, reply_to: &MessageHeader<'_>) -> Result<Self> {
184        let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
185        self.header
186            .fields_mut()
187            .replace(MessageField::ReplySerial(*serial));
188
189        if let Some(sender) = reply_to.sender()? {
190            self.destination(sender.to_owned())
191        } else {
192            Ok(self)
193        }
194    }
195
196    /// Build the [`Message`] with the given body.
197    ///
198    /// You may pass `()` as the body if the message has no body.
199    ///
200    /// The caller is currently required to ensure that the resulting message contains the headers
201    /// as compliant with the [specification]. Additional checks may be added to this builder over
202    /// time as needed.
203    ///
204    /// [specification]:
205    /// https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-header-fields
206    pub fn build<B>(self, body: &B) -> Result<Message>
207    where
208        B: serde::ser::Serialize + DynamicType,
209    {
210        let ctxt = dbus_context!(0);
211
212        // Note: this iterates the body twice, but we prefer efficient handling of large messages
213        // to efficient handling of ones that are complex to serialize.
214        #[cfg(unix)]
215        let (body_len, fds_len) = zvariant::serialized_size_fds(ctxt, body)?;
216        #[cfg(not(unix))]
217        let body_len = zvariant::serialized_size(ctxt, body)?;
218
219        let signature = body.dynamic_signature();
220
221        self.build_generic(
222            signature,
223            body_len,
224            move |cursor| {
225                #[cfg(unix)]
226                {
227                    let (_, fds) = zvariant::to_writer_fds(cursor, ctxt, body)?;
228                    Ok::<Vec<RawFd>, Error>(fds)
229                }
230                #[cfg(not(unix))]
231                {
232                    zvariant::to_writer(cursor, ctxt, body)?;
233                    Ok::<(), Error>(())
234                }
235            },
236            #[cfg(unix)]
237            fds_len,
238        )
239    }
240
241    /// Create a new message from a raw slice of bytes to populate the body with, rather than by
242    /// serializing a value. The message body will be the exact bytes.
243    ///
244    /// # Safety
245    ///
246    /// This method is unsafe because it can be used to build an invalid message.
247    pub unsafe fn build_raw_body<'b, S>(
248        self,
249        body_bytes: &[u8],
250        signature: S,
251        #[cfg(unix)] fds: Vec<RawFd>,
252    ) -> Result<Message>
253    where
254        S: TryInto<Signature<'b>>,
255        S::Error: Into<Error>,
256    {
257        let signature: Signature<'b> = signature.try_into().map_err(Into::into)?;
258        #[cfg(unix)]
259        let fds_len = fds.len();
260
261        self.build_generic(
262            signature,
263            body_bytes.len(),
264            move |cursor: &mut Cursor<&mut Vec<u8>>| {
265                cursor.write_all(body_bytes)?;
266
267                #[cfg(unix)]
268                return Ok::<Vec<RawFd>, Error>(fds);
269
270                #[cfg(not(unix))]
271                return Ok::<(), Error>(());
272            },
273            #[cfg(unix)]
274            fds_len,
275        )
276    }
277
278    fn build_generic<WriteFunc>(
279        self,
280        mut signature: Signature<'_>,
281        body_len: usize,
282        write_body: WriteFunc,
283        #[cfg(unix)] fds_len: usize,
284    ) -> Result<Message>
285    where
286        WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
287    {
288        let ctxt = dbus_context!(0);
289        let mut header = self.header;
290
291        if !signature.is_empty() {
292            if signature.starts_with(zvariant::STRUCT_SIG_START_STR) {
293                // Remove leading and trailing STRUCT delimiters
294                signature = signature.slice(1..signature.len() - 1);
295            }
296            header.fields_mut().add(MessageField::Signature(signature));
297        }
298
299        let body_len_u32 = body_len.try_into().map_err(|_| Error::ExcessData)?;
300        header.primary_mut().set_body_len(body_len_u32);
301
302        #[cfg(unix)]
303        {
304            let fds_len_u32 = fds_len.try_into().map_err(|_| Error::ExcessData)?;
305            if fds_len != 0 {
306                header.fields_mut().add(MessageField::UnixFDs(fds_len_u32));
307            }
308        }
309
310        let hdr_len = zvariant::serialized_size(ctxt, &header)?;
311        // We need to align the body to 8-byte boundary.
312        let body_padding = padding_for_8_bytes(hdr_len);
313        let body_offset = hdr_len + body_padding;
314        let total_len = body_offset + body_len;
315        if total_len > MAX_MESSAGE_SIZE {
316            return Err(Error::ExcessData);
317        }
318        let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
319        let mut cursor = Cursor::new(&mut bytes);
320
321        zvariant::to_writer(&mut cursor, ctxt, &header)?;
322        for _ in 0..body_padding {
323            cursor.write_all(&[0u8])?;
324        }
325        #[cfg(unix)]
326        let fds = write_body(&mut cursor)?;
327        #[cfg(not(unix))]
328        write_body(&mut cursor)?;
329
330        let primary_header = header.into_primary();
331        let header: MessageHeader<'_> = zvariant::from_slice(&bytes, ctxt)?;
332        let quick_fields = QuickMessageFields::new(&bytes, &header)?;
333
334        Ok(Message {
335            primary_header,
336            quick_fields,
337            bytes,
338            body_offset,
339            #[cfg(unix)]
340            fds: Arc::new(RwLock::new(Fds::Raw(fds))),
341            recv_seq: MessageSequence::default(),
342        })
343    }
344}
345
346impl<'m> From<MessageHeader<'m>> for MessageBuilder<'m> {
347    fn from(mut header: MessageHeader<'m>) -> Self {
348        // Signature and Fds are added by body* methods.
349        let fields = header.fields_mut();
350        fields.remove(MessageFieldCode::Signature);
351        fields.remove(MessageFieldCode::UnixFDs);
352
353        Self { header }
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::MessageBuilder;
    use crate::Error;
    use test_log::test;

    /// Build a message from a hand-encoded `ai` body and check that it decodes back.
    #[test]
    fn test_raw() -> Result<(), Error> {
        // An `ai` body is a u32 byte length followed by the i32 elements. The builder encodes
        // the header in native byte order, so the body must be native-endian as well;
        // hard-coding little-endian bytes made this test fail on big-endian hosts.
        let elements: [i32; 4] = [1, 2, 3, 4];
        // 16 bytes, so the cast cannot truncate.
        let array_len = std::mem::size_of_val(&elements) as u32;
        let mut raw_body: Vec<u8> = Vec::new();
        raw_body.extend_from_slice(&array_len.to_ne_bytes());
        for element in &elements {
            raw_body.extend_from_slice(&element.to_ne_bytes());
        }

        let message_builder = MessageBuilder::signal("/", "test.test", "test")?;
        // SAFETY: `raw_body` is a valid native-endian encoding of an `ai` value, and it
        // references no file descriptors.
        let message = unsafe {
            message_builder.build_raw_body(
                &raw_body,
                "ai",
                #[cfg(unix)]
                vec![],
            )?
        };

        let output: Vec<i32> = message.body()?;
        assert_eq!(output, elements.to_vec());

        Ok(())
    }
}