use std::{
convert::TryInto,
io::{Cursor, Write},
};
#[cfg(unix)]
use crate::Fds;
#[cfg(unix)]
use std::{
os::unix::io::RawFd,
sync::{Arc, RwLock},
};
use enumflags2::BitFlags;
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
use crate::{
utils::padding_for_8_bytes,
zvariant::{DynamicType, EncodingContext, ObjectPath, Signature},
Error, Message, MessageField, MessageFieldCode, MessageFields, MessageFlags, MessageHeader,
MessagePrimaryHeader, MessageSequence, MessageType, QuickMessageFields, Result,
MAX_MESSAGE_SIZE,
};
/// Value returned by the body-writing callback handed to `build_generic`.
///
/// On Unix the callback yields the raw file descriptors that travel with the body.
#[cfg(unix)]
type BuildGenericResult = Vec<RawFd>;
/// Value returned by the body-writing callback handed to `build_generic`.
///
/// On non-Unix targets no file descriptors are collected, so the callback yields `()`.
#[cfg(not(unix))]
type BuildGenericResult = ();
// Expands to a D-Bus `EncodingContext` parameterised with `byteorder::NativeEndian`.
//
// The single expression argument is forwarded unchanged to `EncodingContext::new_dbus`.
macro_rules! dbus_context {
    ($position:expr) => {
        EncodingContext::<byteorder::NativeEndian>::new_dbus($position)
    };
}
/// Builder that accumulates a message header and finally produces a [`Message`].
///
/// The header fields are filled in through the builder methods; the body is supplied when
/// one of the `build*` methods is called.
#[derive(Debug, Clone)]
pub struct MessageBuilder<'a> {
// Header (primary part plus fields) of the message under construction.
header: MessageHeader<'a>,
}
impl<'a> MessageBuilder<'a> {
/// Creates a builder whose header has the given message type and an empty field set.
fn new(msg_type: MessageType) -> Self {
    // NOTE(review): the literal `0` is presumably the initial body length, which
    // `build_generic` later overwrites through `set_body_len` — confirm against
    // `MessagePrimaryHeader::new`.
    Self {
        header: MessageHeader::new(
            MessagePrimaryHeader::new(msg_type, 0),
            MessageFields::new(),
        ),
    }
}
/// Starts a method-call message addressed to `method_name` on the object at `path`.
///
/// # Errors
///
/// Fails if `path` cannot be converted into an [`ObjectPath`] or `method_name` into a
/// [`MemberName`].
pub fn method_call<'p: 'a, 'm: 'a, P, M>(path: P, method_name: M) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    M: TryInto<MemberName<'m>>,
    P::Error: Into<Error>,
    M::Error: Into<Error>,
{
    let with_path = Self::new(MessageType::MethodCall).path(path)?;
    with_path.member(method_name)
}
/// Starts a signal message emitted from `path`, for member `name` of `interface`.
///
/// # Errors
///
/// Fails if any of the three arguments cannot be converted into its target type
/// ([`ObjectPath`], [`InterfaceName`] and [`MemberName`] respectively).
pub fn signal<'p: 'a, 'i: 'a, 'm: 'a, P, I, M>(path: P, interface: I, name: M) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    I: TryInto<InterfaceName<'i>>,
    M: TryInto<MemberName<'m>>,
    P::Error: Into<Error>,
    I::Error: Into<Error>,
    M::Error: Into<Error>,
{
    // Fields are set in the order path, interface, member; the first failing conversion
    // aborts the chain.
    let with_path = Self::new(MessageType::Signal).path(path)?;
    let with_interface = with_path.interface(interface)?;
    with_interface.member(name)
}
/// Starts a method-return message answering the message described by `reply_to`.
///
/// # Errors
///
/// Propagates any error from `reply_to`, e.g. when the original header has no serial
/// number.
pub fn method_return(reply_to: &MessageHeader<'_>) -> Result<Self> {
    let reply = Self::new(MessageType::MethodReturn);
    reply.reply_to(reply_to)
}
pub fn error<'e: 'a, E>(reply_to: &MessageHeader<'_>, name: E) -> Result<Self>
where
E: TryInto<ErrorName<'e>>,
E::Error: Into<Error>,
{
Self::new(MessageType::Error)
.error_name(name)?
.reply_to(reply_to)
}
/// Adds `flag` to the flags already set on the message.
///
/// # Errors
///
/// Returns [`Error::InvalidField`] when `flag` contains
/// [`MessageFlags::NoReplyExpected`] and the message is not a method call. An error from
/// reading the header's message type is propagated as well.
pub fn with_flags(mut self, flag: MessageFlags) -> Result<Self> {
    // The message type is read first so that a failure there is always reported.
    let is_method_call = self.header.message_type()? == MessageType::MethodCall;
    let wants_no_reply = BitFlags::from_flag(flag).contains(MessageFlags::NoReplyExpected);
    if wants_no_reply && !is_method_call {
        return Err(Error::InvalidField);
    }

    let primary = self.header.primary_mut();
    let combined = primary.flags() | flag;
    primary.set_flags(combined);
    Ok(self)
}
/// Stores `sender` as the message's `Sender` header field.
///
/// # Errors
///
/// Fails if `sender` cannot be converted into a [`UniqueName`].
pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
where
    S: TryInto<UniqueName<'s>>,
    S::Error: Into<Error>,
{
    let unique_name: UniqueName<'s> = sender.try_into().map_err(Into::into)?;
    let fields = self.header.fields_mut();
    fields.replace(MessageField::Sender(unique_name));
    Ok(self)
}
/// Stores `path` as the message's `Path` header field.
///
/// # Errors
///
/// Fails if `path` cannot be converted into an [`ObjectPath`].
pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    P::Error: Into<Error>,
{
    let object_path: ObjectPath<'p> = path.try_into().map_err(Into::into)?;
    let fields = self.header.fields_mut();
    fields.replace(MessageField::Path(object_path));
    Ok(self)
}
/// Stores `interface` as the message's `Interface` header field.
///
/// # Errors
///
/// Fails if `interface` cannot be converted into an [`InterfaceName`].
pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
where
    I: TryInto<InterfaceName<'i>>,
    I::Error: Into<Error>,
{
    let interface_name: InterfaceName<'i> = interface.try_into().map_err(Into::into)?;
    let fields = self.header.fields_mut();
    fields.replace(MessageField::Interface(interface_name));
    Ok(self)
}
/// Stores `member` as the message's `Member` header field.
///
/// # Errors
///
/// Fails if `member` cannot be converted into a [`MemberName`].
pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
where
    M: TryInto<MemberName<'m>>,
    M::Error: Into<Error>,
{
    let member_name: MemberName<'m> = member.try_into().map_err(Into::into)?;
    let fields = self.header.fields_mut();
    fields.replace(MessageField::Member(member_name));
    Ok(self)
}
fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
where
E: TryInto<ErrorName<'e>>,
E::Error: Into<Error>,
{
self.header.fields_mut().replace(MessageField::ErrorName(
error.try_into().map_err(Into::into)?,
));
Ok(self)
}
/// Stores `destination` as the message's `Destination` header field.
///
/// # Errors
///
/// Fails if `destination` cannot be converted into a [`BusName`].
pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
where
    D: TryInto<BusName<'d>>,
    D::Error: Into<Error>,
{
    let bus_name: BusName<'d> = destination.try_into().map_err(Into::into)?;
    let fields = self.header.fields_mut();
    fields.replace(MessageField::Destination(bus_name));
    Ok(self)
}
/// Fills in the fields that tie this message to the one described by `reply_to`.
///
/// The serial number of `reply_to` is stored as the `ReplySerial` field. If `reply_to`
/// has a sender, that sender is additionally stored as this message's destination.
///
/// # Errors
///
/// Returns [`Error::MissingField`] when `reply_to` carries no serial number, and
/// propagates errors from reading the sender or setting the destination.
fn reply_to(mut self, reply_to: &MessageHeader<'_>) -> Result<Self> {
    let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
    let reply_serial = MessageField::ReplySerial(*serial);
    self.header.fields_mut().replace(reply_serial);

    match reply_to.sender()? {
        Some(sender) => self.destination(sender.to_owned()),
        None => Ok(self),
    }
}
pub fn build<B>(self, body: &B) -> Result<Message>
where
B: serde::ser::Serialize + DynamicType,
{
let ctxt = dbus_context!(0);
#[cfg(unix)]
let (body_len, fds_len) = zvariant::serialized_size_fds(ctxt, body)?;
#[cfg(not(unix))]
let body_len = zvariant::serialized_size(ctxt, body)?;
let signature = body.dynamic_signature();
self.build_generic(
signature,
body_len,
move |cursor| {
#[cfg(unix)]
{
let (_, fds) = zvariant::to_writer_fds(cursor, ctxt, body)?;
Ok::<Vec<RawFd>, Error>(fds)
}
#[cfg(not(unix))]
{
zvariant::to_writer(cursor, ctxt, body)?;
Ok::<(), Error>(())
}
},
#[cfg(unix)]
fds_len,
)
}
/// Finishes the message using an already-encoded body.
///
/// `body_bytes` is copied into the message as-is and `signature` is recorded as the body
/// signature. On Unix, `fds` are attached to the resulting message.
///
/// # Safety
///
/// Nothing in this function checks `body_bytes` against `signature`. The caller must
/// ensure the bytes are a valid encoding of that signature — presumably in the same
/// D-Bus context (`dbus_context!(0)`) that the header is serialized with; confirm
/// against the code that decodes message bodies.
///
/// # Errors
///
/// Fails if `signature` cannot be converted into a [`Signature`], if writing the bytes
/// fails, or if `build_generic` rejects the message.
pub unsafe fn build_raw_body<'b, S>(
    self,
    body_bytes: &[u8],
    signature: S,
    #[cfg(unix)] fds: Vec<RawFd>,
) -> Result<Message>
where
    S: TryInto<Signature<'b>>,
    S::Error: Into<Error>,
{
    // Take the descriptor count before `fds` is moved into the closure below.
    #[cfg(unix)]
    let fds_len = fds.len();
    let body_len = body_bytes.len();
    let signature: Signature<'b> = signature.try_into().map_err(Into::into)?;

    self.build_generic(
        signature,
        body_len,
        move |out: &mut Cursor<&mut Vec<u8>>| {
            out.write_all(body_bytes)?;

            #[cfg(unix)]
            return Ok::<Vec<RawFd>, Error>(fds);

            #[cfg(not(unix))]
            return Ok::<(), Error>(());
        },
        #[cfg(unix)]
        fds_len,
    )
}
/// Shared back end of the `build*` methods: completes the header, serializes it, and lets
/// `write_body` append the body.
///
/// * `signature` – body signature; recorded in the header unless it is empty.
/// * `body_len` – number of body bytes `write_body` is expected to produce; stored in the
///   primary header and used for the size check and buffer capacity.
/// * `write_body` – callback that writes the body at the cursor's current position and
///   returns a `BuildGenericResult`.
/// * `fds_len` (Unix only) – number of file descriptors; recorded in the header when
///   non-zero.
///
/// # Errors
///
/// Returns [`Error::ExcessData`] if `body_len` or `fds_len` does not fit in a `u32`, or if
/// header + padding + body exceeds `MAX_MESSAGE_SIZE`. Serialization, I/O and
/// `write_body` errors are propagated.
fn build_generic<WriteFunc>(
self,
mut signature: Signature<'_>,
body_len: usize,
write_body: WriteFunc,
#[cfg(unix)] fds_len: usize,
) -> Result<Message>
where
WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
{
let ctxt = dbus_context!(0);
let mut header = self.header;
if !signature.is_empty() {
// Drop the first and last character when the signature begins with the struct
// start delimiter, so the header carries the inner signature only.
// NOTE(review): only the leading character is inspected; a multi-type signature
// that merely begins with a struct (e.g. `(i)(s)`) would also be trimmed — confirm
// callers never pass one.
if signature.starts_with(zvariant::STRUCT_SIG_START_STR) {
signature = signature.slice(1..signature.len() - 1);
}
header.fields_mut().add(MessageField::Signature(signature));
}
// The primary header stores the body length as `u32`.
let body_len_u32 = body_len.try_into().map_err(|_| Error::ExcessData)?;
header.primary_mut().set_body_len(body_len_u32);
#[cfg(unix)]
{
let fds_len_u32 = fds_len.try_into().map_err(|_| Error::ExcessData)?;
// The `UnixFDs` field is only added when there is at least one descriptor.
if fds_len != 0 {
header.fields_mut().add(MessageField::UnixFDs(fds_len_u32));
}
}
// All header fields must be final at this point: the serialized header size decides
// where the body starts.
let hdr_len = zvariant::serialized_size(ctxt, &header)?;
let body_padding = padding_for_8_bytes(hdr_len);
let body_offset = hdr_len + body_padding;
let total_len = body_offset + body_len;
if total_len > MAX_MESSAGE_SIZE {
return Err(Error::ExcessData);
}
let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
let mut cursor = Cursor::new(&mut bytes);
zvariant::to_writer(&mut cursor, ctxt, &header)?;
// Zero-fill the gap between the end of the header and `body_offset`.
for _ in 0..body_padding {
cursor.write_all(&[0u8])?;
}
#[cfg(unix)]
let fds = write_body(&mut cursor)?;
#[cfg(not(unix))]
write_body(&mut cursor)?;
// `into_primary` consumes the locally built header; the header is then parsed back from
// the serialized bytes to construct the quick fields.
let primary_header = header.into_primary();
let header: MessageHeader<'_> = zvariant::from_slice(&bytes, ctxt)?;
let quick_fields = QuickMessageFields::new(&bytes, &header)?;
Ok(Message {
primary_header,
quick_fields,
bytes,
body_offset,
#[cfg(unix)]
fds: Arc::new(RwLock::new(Fds::Raw(fds))),
recv_seq: MessageSequence::default(),
})
}
}
/// Turns an existing header into a builder for a new message.
///
/// The `Signature` and `UnixFDs` fields are stripped because `build_generic` adds them
/// again based on the body that is eventually supplied.
impl<'m> From<MessageHeader<'m>> for MessageBuilder<'m> {
    fn from(mut header: MessageHeader<'m>) -> Self {
        header.fields_mut().remove(MessageFieldCode::Signature);
        header.fields_mut().remove(MessageFieldCode::UnixFDs);
        Self { header }
    }
}
#[cfg(test)]
mod tests {
    use super::MessageBuilder;
    use crate::Error;
    use test_log::test;

    /// Builds a signal from a hand-encoded `ai` body and checks that it decodes back to
    /// the original integers.
    #[test]
    fn test_raw() -> Result<(), Error> {
        // The builder serializes with a native-endian context, so the raw body is
        // produced in the host's byte order instead of hard-coding little-endian bytes
        // (which would only be a valid encoding on little-endian hosts).
        // Layout: array length in bytes (16) followed by the four 32-bit elements.
        let mut raw_body: Vec<u8> = Vec::with_capacity(20);
        for word in [16u32, 1, 2, 3, 4].iter() {
            raw_body.extend_from_slice(&word.to_ne_bytes());
        }

        let message_builder = MessageBuilder::signal("/", "test.test", "test")?;
        // SAFETY: `raw_body` was encoded above as a D-Bus `ai` value in native byte
        // order, matching the signature passed alongside it.
        let message = unsafe {
            message_builder.build_raw_body(
                &raw_body,
                "ai",
                #[cfg(unix)]
                vec![],
            )?
        };

        let output: Vec<i32> = message.body()?;
        assert_eq!(output, vec![1, 2, 3, 4]);
        Ok(())
    }
}