zbus/message_fields.rs

1use serde::{Deserialize, Serialize};
2use static_assertions::assert_impl_all;
3use std::convert::{TryFrom, TryInto};
4use zbus_names::{InterfaceName, MemberName};
5use zvariant::{ObjectPath, Type};
6
7use crate::{Message, MessageField, MessageFieldCode, MessageHeader, Result};
8
// Initial capacity reserved for header fields (see `Default for MessageFields`). The real
// number of fields a message can carry is lower (around 10 at most), so this is simply
// rounded up to the next multiple of 8.
const MAX_FIELDS_IN_MESSAGE: usize = 2 * 8;
11
12/// A collection of [`MessageField`] instances.
13///
14/// [`MessageField`]: enum.MessageField.html
15#[derive(Debug, Clone, Serialize, Deserialize, Type)]
16pub struct MessageFields<'m>(#[serde(borrow)] Vec<MessageField<'m>>);
17
18assert_impl_all!(MessageFields<'_>: Send, Sync, Unpin);
19
20impl<'m> MessageFields<'m> {
21    /// Creates an empty collection of fields.
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Appends a [`MessageField`] to the collection of fields in the message.
27    ///
28    /// [`MessageField`]: enum.MessageField.html
29    pub fn add<'f: 'm>(&mut self, field: MessageField<'f>) {
30        self.0.push(field);
31    }
32
33    /// Replaces a [`MessageField`] from the collection of fields with one with the same code,
34    /// returning the old value if present.
35    ///
36    /// [`MessageField`]: enum.MessageField.html
37    pub fn replace<'f: 'm>(&mut self, field: MessageField<'f>) -> Option<MessageField<'m>> {
38        let code = field.code();
39        if let Some(found) = self.0.iter_mut().find(|f| f.code() == code) {
40            return Some(std::mem::replace(found, field));
41        }
42        self.add(field);
43        None
44    }
45
46    /// Returns a slice with all the [`MessageField`] in the message.
47    ///
48    /// [`MessageField`]: enum.MessageField.html
49    pub fn get(&self) -> &[MessageField<'m>] {
50        &self.0
51    }
52
53    /// Gets a reference to a specific [`MessageField`] by its code.
54    ///
55    /// Returns `None` if the message has no such field.
56    ///
57    /// [`MessageField`]: enum.MessageField.html
58    pub fn get_field(&self, code: MessageFieldCode) -> Option<&MessageField<'m>> {
59        self.0.iter().find(|f| f.code() == code)
60    }
61
62    /// Consumes the `MessageFields` and returns a specific [`MessageField`] by its code.
63    ///
64    /// Returns `None` if the message has no such field.
65    ///
66    /// [`MessageField`]: enum.MessageField.html
67    pub fn into_field(self, code: MessageFieldCode) -> Option<MessageField<'m>> {
68        self.0.into_iter().find(|f| f.code() == code)
69    }
70
71    /// Remove the field matching the `code`.
72    ///
73    /// Returns `true` if a field was found and removed, `false` otherwise.
74    pub(crate) fn remove(&mut self, code: MessageFieldCode) -> bool {
75        match self.0.iter().enumerate().find(|(_, f)| f.code() == code) {
76            Some((i, _)) => {
77                self.0.remove(i);
78
79                true
80            }
81            None => false,
82        }
83    }
84}
85
/// A byte range of a field in a Message, used in [`QuickMessageFields`].
///
/// Ranges with `end == 0` and a `start` of 0 or 1 are reserved as markers and are never read
/// from the buffer:
/// - `{ start: 1, end: 0 }` ([`FieldPos::new_not_present`]) means "not present";
/// - `{ start: 0, end: 0 }` (the `Default` value) means "not cached".
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct FieldPos {
    /// Offset of the first byte of the field within the message buffer.
    start: u32,
    /// Offset one past the last byte of the field within the message buffer.
    end: u32,
}

impl FieldPos {
    /// Returns the marker meaning "this field is absent from the message".
    pub fn new_not_present() -> Self {
        FieldPos { start: 1, end: 0 }
    }

    /// Records where `field_buf` lives inside `msg_buf`.
    ///
    /// Returns `None` if `field_buf` is not entirely contained in `msg_buf`, or if either
    /// offset does not fit in a `u32`.
    pub fn build(msg_buf: &[u8], field_buf: &str) -> Option<Self> {
        // The field must be a sub-slice of the message: compare addresses, not contents.
        let offset = (field_buf.as_ptr() as usize).checked_sub(msg_buf.as_ptr() as usize)?;
        // `checked_add` also rejects ranges that would wrap around the address space.
        let end = offset.checked_add(field_buf.len())?;
        if end > msg_buf.len() {
            return None;
        }
        Some(FieldPos {
            start: u32::try_from(offset).ok()?,
            end: u32::try_from(end).ok()?,
        })
    }

    /// Records the position of an optional field.
    ///
    /// Yields the "not present" marker when `field` is `None` or when it does not borrow
    /// from `msg_buf` (see [`FieldPos::build`]).
    pub fn new<T>(msg_buf: &[u8], field: Option<&T>) -> Self
    where
        T: std::ops::Deref<Target = str>,
    {
        match field {
            Some(value) => Self::build(msg_buf, &**value).unwrap_or_else(Self::new_not_present),
            None => Self::new_not_present(),
        }
    }

    /// Reassembles a previously cached field as a `T` borrowing from `msg_buf`.
    ///
    /// Returns `None` for the "not cached" and "not present" markers.
    ///
    /// # Panics
    ///
    /// `msg_buf` must be the same buffer that was passed to [`FieldPos::build`]. With any
    /// other buffer this may panic (range out of bounds, invalid UTF-8, or a failing
    /// `T::try_from`) or silently return unrelated data.
    pub fn read<'m, T>(&self, msg_buf: &'m [u8]) -> Option<T>
    where
        T: TryFrom<&'m str>,
        T::Error: std::fmt::Debug,
    {
        let FieldPos { start, end } = *self;
        if end == 0 && start <= 1 {
            return None;
        }
        let raw = &msg_buf[start as usize..end as usize];
        let text = std::str::from_utf8(raw).expect("Invalid utf8 when reconstructing string");
        // The value handed to `build` was already a validated field, so converting the same
        // bytes back is expected to succeed.
        let value = T::try_from(text).expect("Invalid field reconstruction");
        Some(value)
    }
}
148
149/// A cache of some commonly-used fields of the header of a Message.
150#[derive(Debug, Default, Copy, Clone)]
151pub(crate) struct QuickMessageFields {
152    path: FieldPos,
153    interface: FieldPos,
154    member: FieldPos,
155    reply_serial: Option<u32>,
156}
157
158impl QuickMessageFields {
159    pub fn new(buf: &[u8], header: &MessageHeader<'_>) -> Result<Self> {
160        Ok(Self {
161            path: FieldPos::new(buf, header.path()?),
162            interface: FieldPos::new(buf, header.interface()?),
163            member: FieldPos::new(buf, header.member()?),
164            reply_serial: header.reply_serial()?,
165        })
166    }
167
168    pub fn path<'m>(&self, msg: &'m Message) -> Option<ObjectPath<'m>> {
169        self.path.read(msg.as_bytes())
170    }
171
172    pub fn interface<'m>(&self, msg: &'m Message) -> Option<InterfaceName<'m>> {
173        self.interface.read(msg.as_bytes())
174    }
175
176    pub fn member<'m>(&self, msg: &'m Message) -> Option<MemberName<'m>> {
177        self.member.read(msg.as_bytes())
178    }
179
180    pub fn reply_serial(&self) -> Option<u32> {
181        self.reply_serial
182    }
183}
184
185impl<'m> Default for MessageFields<'m> {
186    fn default() -> Self {
187        Self(Vec::with_capacity(MAX_FIELDS_IN_MESSAGE))
188    }
189}
190
191impl<'m> std::ops::Deref for MessageFields<'m> {
192    type Target = [MessageField<'m>];
193
194    fn deref(&self) -> &Self::Target {
195        self.get()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::{MessageField, MessageFields};

    #[test]
    fn test() {
        // `add` always appends, even when a field with the same code already exists.
        let mut mf = MessageFields::new();
        assert_eq!(mf.len(), 0);
        mf.add(MessageField::ReplySerial(42));
        assert_eq!(mf.len(), 1);
        mf.add(MessageField::ReplySerial(43));
        assert_eq!(mf.len(), 2);
        assert!(matches!(
            mf.get(),
            [MessageField::ReplySerial(42), MessageField::ReplySerial(43)]
        ));

        // `replace` appends when the code is new, and otherwise overwrites in place while
        // handing back the previous value.
        let mut mf = MessageFields::new();
        assert_eq!(mf.len(), 0);
        assert!(mf.replace(MessageField::ReplySerial(42)).is_none());
        assert_eq!(mf.len(), 1);
        assert!(matches!(
            mf.replace(MessageField::ReplySerial(43)),
            Some(MessageField::ReplySerial(42))
        ));
        assert_eq!(mf.len(), 1);
        assert!(matches!(mf.get(), [MessageField::ReplySerial(43)]));

        // Lookup and removal by code.
        let code = || MessageField::ReplySerial(0).code();
        assert!(matches!(
            mf.get_field(code()),
            Some(MessageField::ReplySerial(43))
        ));
        assert!(matches!(
            mf.clone().into_field(code()),
            Some(MessageField::ReplySerial(43))
        ));
        assert!(mf.remove(code()));
        assert!(!mf.remove(code()));
        assert_eq!(mf.len(), 0);
        assert!(mf.get_field(code()).is_none());
    }
}
219}