use std::convert::TryFrom;
use enumflags2::{bitflags, BitFlags};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use static_assertions::assert_impl_all;
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
use zvariant::{EncodingContext, ObjectPath, Signature, Type};
use crate::{Error, MessageField, MessageFieldCode, MessageFields};
/// Size in bytes of the fixed-length primary header at the start of every message.
pub(crate) const PRIMARY_HEADER_SIZE: usize = 12;

/// Smallest acceptable message size: the primary header plus the 4 bytes that immediately
/// follow it.
pub(crate) const MIN_MESSAGE_SIZE: usize = PRIMARY_HEADER_SIZE + 4;

/// Maximum message size in bytes (128 MiB).
pub(crate) const MAX_MESSAGE_SIZE: usize = 128 * 1024 * 1024;

/// Byte-order marker carried as the first field of the primary header, encoded as a single
/// ASCII byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize_repr, Serialize_repr, Type)]
pub enum EndianSig {
    /// Big-endian, encoded as the byte `b'B'`.
    Big = b'B',
    /// Little-endian, encoded as the byte `b'l'`.
    Little = b'l',
}

assert_impl_all!(EndianSig: Send, Sync, Unpin);
impl TryFrom<u8> for EndianSig {
type Error = Error;
fn try_from(val: u8) -> Result<EndianSig, Error> {
match val {
b'B' => Ok(EndianSig::Big),
b'l' => Ok(EndianSig::Little),
_ => Err(Error::IncorrectEndian),
}
}
}
/// The [`EndianSig`] matching the byte order of the target this crate is compiled for.
pub const NATIVE_ENDIAN_SIG: EndianSig = if cfg!(target_endian = "big") {
    EndianSig::Big
} else {
    EndianSig::Little
};
/// The kind of a message, stored as a single byte in the primary header.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Deserialize_repr, Serialize_repr, Type)]
pub enum MessageType {
    /// Byte `0`; also the result of converting any unrecognized byte via `From<u8>`.
    Invalid = 0,
    /// Byte `1` (D-Bus `METHOD_CALL`).
    MethodCall = 1,
    /// Byte `2` (D-Bus `METHOD_RETURN`).
    MethodReturn = 2,
    /// Byte `3` (D-Bus `ERROR`).
    Error = 3,
    /// Byte `4` (D-Bus `SIGNAL`).
    Signal = 4,
}

assert_impl_all!(MessageType: Send, Sync, Unpin);
impl From<u8> for MessageType {
fn from(val: u8) -> MessageType {
match val {
1 => MessageType::MethodCall,
2 => MessageType::MethodReturn,
3 => MessageType::Error,
4 => MessageType::Signal,
_ => MessageType::Invalid,
}
}
}
/// Individual flag bits that can be combined into a `BitFlags<MessageFlags>` set, which the
/// primary header stores in its `flags` field.
#[bitflags]
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Type)]
pub enum MessageFlags {
/// Bit `0x1` (D-Bus `NO_REPLY_EXPECTED`).
NoReplyExpected = 0x1,
/// Bit `0x2` (D-Bus `NO_AUTO_START`).
NoAutoStart = 0x2,
/// Bit `0x4` (D-Bus `ALLOW_INTERACTIVE_AUTHORIZATION`).
AllowInteractiveAuth = 0x4,
}
assert_impl_all!(MessageFlags: Send, Sync, Unpin);
/// A message serial number that may still be unassigned.
///
/// The `OnceCell` lets a header be created without a serial (see `MessagePrimaryHeader::new`)
/// and have one filled in exactly once later (see `serial_num_or_init`).
#[derive(Clone, Debug)]
struct SerialNum(OnceCell<u32>);
impl zvariant::Type for SerialNum {
fn signature() -> Signature<'static> {
u32::signature()
}
}
impl Serialize for SerialNum {
    /// Serializes the serial number as a plain `u32`.
    ///
    /// A serial that has not been assigned yet is written as `0` (the `u32` default).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `u32` is `Copy`: copy it out of the cell instead of going through `Clone`.
        self.0.get().copied().unwrap_or_default().serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for SerialNum {
    /// Reads a `u32` and stores it as an already-assigned serial number.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        u32::deserialize(deserializer).map(|serial| SerialNum(OnceCell::from(serial)))
    }
}
/// The fixed-size leading part of a message header.
///
/// NOTE: serde's derives serialize struct fields in declaration order, so the order of the
/// fields below is the encoded layout and must not be rearranged.
#[derive(Clone, Debug, Serialize, Deserialize, Type)]
pub struct MessagePrimaryHeader {
// Byte-order marker; `new` sets it to `NATIVE_ENDIAN_SIG`.
endian_sig: EndianSig,
msg_type: MessageType,
flags: BitFlags<MessageFlags>,
// `new` sets this to 1.
protocol_version: u8,
// Body length as supplied by the caller (presumably in bytes — confirm with callers).
body_len: u32,
// May be unassigned; serialized as 0 while unset.
serial_num: SerialNum,
}
assert_impl_all!(MessagePrimaryHeader: Send, Sync, Unpin);
impl MessagePrimaryHeader {
pub fn new(msg_type: MessageType, body_len: u32) -> Self {
Self {
endian_sig: NATIVE_ENDIAN_SIG,
msg_type,
flags: BitFlags::empty(),
protocol_version: 1,
body_len,
serial_num: SerialNum(OnceCell::new()),
}
}
pub(crate) fn read(buf: &[u8]) -> Result<(MessagePrimaryHeader, u32), Error> {
let ctx = EncodingContext::<byteorder::NativeEndian>::new_dbus(0);
let primary_header = zvariant::from_slice(buf, ctx)?;
let fields_len = zvariant::from_slice(&buf[PRIMARY_HEADER_SIZE..], ctx)?;
Ok((primary_header, fields_len))
}
pub fn endian_sig(&self) -> EndianSig {
self.endian_sig
}
pub fn set_endian_sig(&mut self, sig: EndianSig) {
self.endian_sig = sig;
}
pub fn msg_type(&self) -> MessageType {
self.msg_type
}
pub fn set_msg_type(&mut self, msg_type: MessageType) {
self.msg_type = msg_type;
}
pub fn flags(&self) -> BitFlags<MessageFlags> {
self.flags
}
pub fn set_flags(&mut self, flags: BitFlags<MessageFlags>) {
self.flags = flags;
}
pub fn protocol_version(&self) -> u8 {
self.protocol_version
}
pub fn set_protocol_version(&mut self, version: u8) {
self.protocol_version = version;
}
pub fn body_len(&self) -> u32 {
self.body_len
}
pub fn set_body_len(&mut self, len: u32) {
self.body_len = len;
}
pub fn serial_num(&self) -> Option<&u32> {
self.serial_num.0.get()
}
pub(crate) fn serial_num_or_init<F>(&mut self, f: F) -> &u32
where
F: FnOnce() -> u32,
{
self.serial_num.0.get_or_init(f)
}
}
/// A complete message header: the [`MessagePrimaryHeader`] followed by the header fields.
///
/// NOTE: field order is the serialization order. `#[serde(borrow)]` lets the deserialized
/// `fields` borrow from the input for `'m` rather than copying.
#[derive(Debug, Clone, Serialize, Deserialize, Type)]
pub struct MessageHeader<'m> {
primary: MessagePrimaryHeader,
#[serde(borrow)]
fields: MessageFields<'m>,
}
assert_impl_all!(MessageHeader<'_>: Send, Sync, Unpin);
// Looks up header field `$kind` (a name shared by `MessageFieldCode` and `MessageField`)
// through `$self.fields()` and yields a `Result<Option<_>, Error>`:
// - `Ok(Some($closure(value)))` when the stored entry is the `MessageField::$kind` variant;
// - `Err(Error::InvalidField)` when an entry exists under that code but is another variant;
// - `Ok(None)` when no entry exists under that code.
// The two-argument form uses the identity closure, so the borrowed value is returned as-is.
macro_rules! get_field {
($self:ident, $kind:ident) => {
get_field!($self, $kind, (|v| v))
};
($self:ident, $kind:ident, $closure:tt) => {
// `$closure` is a parenthesized closure literal invoked in place, which clippy reports as
// `redundant_closure_call`.
#[allow(clippy::redundant_closure_call)]
match $self.fields().get_field(MessageFieldCode::$kind) {
Some(MessageField::$kind(value)) => Ok(Some($closure(value))),
Some(_) => Err(Error::InvalidField),
None => Ok(None),
}
};
}
// Variant of `get_field!` for `u32`-valued fields: dereferences the borrowed `&u32`, so the
// expansion has type `Result<Option<u32>, Error>`.
macro_rules! get_field_u32 {
($self:ident, $kind:ident) => {
get_field!($self, $kind, (|v: &u32| *v))
};
}
impl<'m> MessageHeader<'m> {
pub fn new(primary: MessagePrimaryHeader, fields: MessageFields<'m>) -> Self {
Self { primary, fields }
}
pub fn primary(&self) -> &MessagePrimaryHeader {
&self.primary
}
pub fn primary_mut(&mut self) -> &mut MessagePrimaryHeader {
&mut self.primary
}
pub fn into_primary(self) -> MessagePrimaryHeader {
self.primary
}
pub fn fields<'s>(&'s self) -> &'s MessageFields<'m> {
&self.fields
}
pub fn fields_mut<'s>(&'s mut self) -> &'s mut MessageFields<'m> {
&mut self.fields
}
pub fn into_fields(self) -> MessageFields<'m> {
self.fields
}
pub fn message_type(&self) -> Result<MessageType, Error> {
Ok(self.primary().msg_type())
}
pub fn path<'s>(&'s self) -> Result<Option<&ObjectPath<'m>>, Error> {
get_field!(self, Path)
}
pub fn interface<'s>(&'s self) -> Result<Option<&InterfaceName<'m>>, Error> {
get_field!(self, Interface)
}
pub fn member<'s>(&'s self) -> Result<Option<&MemberName<'m>>, Error> {
get_field!(self, Member)
}
pub fn error_name<'s>(&'s self) -> Result<Option<&ErrorName<'m>>, Error> {
get_field!(self, ErrorName)
}
pub fn reply_serial(&self) -> Result<Option<u32>, Error> {
get_field_u32!(self, ReplySerial)
}
pub fn destination<'s>(&'s self) -> Result<Option<&BusName<'m>>, Error> {
get_field!(self, Destination)
}
pub fn sender<'s>(&'s self) -> Result<Option<&UniqueName<'m>>, Error> {
get_field!(self, Sender)
}
pub fn signature(&self) -> Result<Option<&Signature<'m>>, Error> {
get_field!(self, Signature)
}
pub fn unix_fds(&self) -> Result<Option<u32>, Error> {
get_field_u32!(self, UnixFDs)
}
}
#[cfg(test)]
mod tests {
    use crate::{MessageField, MessageFields, MessageHeader, MessagePrimaryHeader, MessageType};
    use std::{
        convert::{TryFrom, TryInto},
        error::Error,
        result::Result,
    };
    use test_log::test;
    use zbus_names::{InterfaceName, MemberName};
    use zvariant::{ObjectPath, Signature};

    /// Exercises every `MessageHeader` field accessor on two headers: a signal carrying
    /// path/interface/member/sender, and a method return carrying the remaining fields.
    #[test]
    fn header() -> Result<(), Box<dyn Error>> {
        // --- Signal header ---
        let obj_path = ObjectPath::try_from("/some/path")?;
        let interface = InterfaceName::try_from("some.interface")?;
        let member_name = MemberName::try_from("Member")?;

        let mut signal_fields = MessageFields::new();
        signal_fields.add(MessageField::Path(obj_path.clone()));
        signal_fields.add(MessageField::Interface(interface.clone()));
        signal_fields.add(MessageField::Member(member_name.clone()));
        signal_fields.add(MessageField::Sender(":1.84".try_into()?));
        let signal = MessageHeader::new(
            MessagePrimaryHeader::new(MessageType::Signal, 77),
            signal_fields,
        );

        assert_eq!(signal.message_type()?, MessageType::Signal);
        assert_eq!(signal.path()?, Some(&obj_path));
        assert_eq!(signal.interface()?, Some(&interface));
        assert_eq!(signal.member()?, Some(&member_name));
        assert_eq!(signal.sender()?.unwrap(), ":1.84");
        // Fields that were never added must read back as absent.
        assert!(signal.error_name()?.is_none());
        assert!(signal.destination()?.is_none());
        assert!(signal.reply_serial()?.is_none());
        assert!(signal.signature()?.is_none());
        assert!(signal.unix_fds()?.is_none());

        // --- Method-return header ---
        let mut reply_fields = MessageFields::new();
        reply_fields.add(MessageField::ErrorName("org.zbus.Error".try_into()?));
        reply_fields.add(MessageField::Destination(":1.11".try_into()?));
        reply_fields.add(MessageField::ReplySerial(88));
        reply_fields.add(MessageField::Signature(Signature::from_str_unchecked("say")));
        reply_fields.add(MessageField::UnixFDs(12));
        let reply = MessageHeader::new(
            MessagePrimaryHeader::new(MessageType::MethodReturn, 77),
            reply_fields,
        );

        assert_eq!(reply.message_type()?, MessageType::MethodReturn);
        assert_eq!(reply.error_name()?.unwrap(), "org.zbus.Error");
        assert_eq!(reply.destination()?.unwrap(), ":1.11");
        assert_eq!(reply.reply_serial()?, Some(88));
        assert_eq!(
            reply.signature()?,
            Some(&Signature::from_str_unchecked("say"))
        );
        assert_eq!(reply.unix_fds()?, Some(12));
        // Routing fields belong to the signal only.
        assert!(reply.path()?.is_none());
        assert!(reply.interface()?.is_none());
        assert!(reply.member()?.is_none());
        assert!(reply.sender()?.is_none());

        Ok(())
    }
}