// zbus_macros/error.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{spanned::Spanned, Data, DeriveInput, Error, Fields, Ident, Variant};
4use zvariant_utils::def_attrs;
5
6// FIXME: The list name should once be "zbus" instead of "dbus_error" (like in serde).
7def_attrs! {
8    crate dbus_error;
9
10    pub StructAttributes("struct") {
11        prefix str,
12        impl_display bool
13    };
14
15    pub VariantAttributes("enum variant") {
16        name str,
17        zbus_error none
18    };
19}
20
21use crate::utils::*;
22
23pub fn expand_derive(input: DeriveInput) -> Result<TokenStream, Error> {
24    let StructAttributes {
25        prefix,
26        impl_display,
27    } = StructAttributes::parse(&input.attrs)?;
28    let prefix = prefix.unwrap_or_else(|| "org.freedesktop.DBus".to_string());
29    let generate_display = impl_display.unwrap_or(true);
30
31    let (_vis, name, _generics, data) = match input.data {
32        Data::Enum(data) => (input.vis, input.ident, input.generics, data),
33        _ => return Err(Error::new(input.span(), "only enums supported")),
34    };
35
36    let zbus = zbus_path();
37    let mut replies = quote! {};
38    let mut error_names = quote! {};
39    let mut error_descriptions = quote! {};
40    let mut error_converts = quote! {};
41
42    let mut zbus_error_variant = None;
43
44    for variant in data.variants {
45        let VariantAttributes { name, zbus_error } = VariantAttributes::parse(&variant.attrs)?;
46
47        let ident = &variant.ident;
48        let name = name.unwrap_or_else(|| ident.to_string());
49
50        let fqn = if !zbus_error {
51            format!("{prefix}.{name}")
52        } else {
53            // The ZBus error variant will always be a hardcoded string.
54            String::from("org.freedesktop.zbus.Error")
55        };
56
57        let error_name = quote! {
58            #zbus::names::ErrorName::from_static_str_unchecked(#fqn)
59        };
60        let e = match variant.fields {
61            Fields::Unit => quote! {
62                Self::#ident => #error_name,
63            },
64            Fields::Unnamed(_) => quote! {
65                Self::#ident(..) => #error_name,
66            },
67            Fields::Named(_) => quote! {
68                Self::#ident { .. } => #error_name,
69            },
70        };
71        error_names.extend(e);
72
73        if zbus_error {
74            if zbus_error_variant.is_some() {
75                panic!("More than 1 `zbus_error` variant found");
76            }
77
78            zbus_error_variant = Some(quote! { #ident });
79        }
80
81        // FIXME: this will error if the first field is not a string as per the dbus spec, but we
82        // may support other cases?
83        let e = match &variant.fields {
84            Fields::Unit => quote! {
85                Self::#ident => None,
86            },
87            Fields::Unnamed(_) => {
88                if zbus_error {
89                    quote! {
90                        Self::#ident(#zbus::Error::MethodError(_, desc, _)) => desc.as_deref(),
91                        Self::#ident(_) => None,
92                    }
93                } else {
94                    quote! {
95                        Self::#ident(desc, ..) => Some(&desc),
96                    }
97                }
98            }
99            Fields::Named(n) => {
100                let f = &n
101                    .named
102                    .first()
103                    .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
104                    .ident;
105                quote! {
106                    Self::#ident { #f, } => Some(#f),
107                }
108            }
109        };
110        error_descriptions.extend(e);
111
112        // The conversion for zbus_error variant is handled separately/explicitly.
113        if !zbus_error {
114            // FIXME: deserialize msg to error field instead, to support variable args
115            let e = match &variant.fields {
116                Fields::Unit => quote! {
117                    #fqn => Self::#ident,
118                },
119                Fields::Unnamed(_) => quote! {
120                    #fqn => { Self::#ident(::std::clone::Clone::clone(desc).unwrap_or_default()) },
121                },
122                Fields::Named(n) => {
123                    let f = &n
124                        .named
125                        .first()
126                        .ok_or_else(|| Error::new(n.span(), "expected at least one field"))?
127                        .ident;
128                    quote! {
129                        #fqn => {
130                            let desc = ::std::clone::Clone::clone(desc).unwrap_or_default();
131
132                            Self::#ident { #f: desc }
133                        }
134                    }
135                }
136            };
137            error_converts.extend(e);
138        }
139
140        let r = gen_reply_for_variant(&variant, zbus_error)?;
141        replies.extend(r);
142    }
143
144    let from_zbus_error_impl = zbus_error_variant
145        .map(|ident| {
146            quote! {
147                impl ::std::convert::From<#zbus::Error> for #name {
148                    fn from(value: #zbus::Error) -> #name {
149                        if let #zbus::Error::MethodError(name, desc, _) = &value {
150                            match name.as_str() {
151                                #error_converts
152                                _ => Self::#ident(value),
153                            }
154                        } else {
155                            Self::#ident(value)
156                        }
157                    }
158                }
159            }
160        })
161        .unwrap_or_default();
162
163    let display_impl = if generate_display {
164        quote! {
165            impl ::std::fmt::Display for #name {
166                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
167                    let name = #zbus::DBusError::name(self);
168                    let description = #zbus::DBusError::description(self).unwrap_or("no description");
169                    ::std::write!(f, "{}: {}", name, description)
170                }
171            }
172        }
173    } else {
174        quote! {}
175    };
176
177    Ok(quote! {
178        impl #zbus::DBusError for #name {
179            fn name(&self) -> #zbus::names::ErrorName {
180                match self {
181                    #error_names
182                }
183            }
184
185            fn description(&self) -> Option<&str> {
186                match self {
187                    #error_descriptions
188                }
189            }
190
191            fn create_reply(&self, call: &#zbus::MessageHeader) -> #zbus::Result<#zbus::Message> {
192                let name = self.name();
193                match self {
194                    #replies
195                }
196            }
197        }
198
199        #display_impl
200
201        impl ::std::error::Error for #name {}
202
203        #from_zbus_error_impl
204    })
205}
206
207fn gen_reply_for_variant(
208    variant: &Variant,
209    zbus_error_variant: bool,
210) -> Result<TokenStream, Error> {
211    let zbus = zbus_path();
212    let ident = &variant.ident;
213    match &variant.fields {
214        Fields::Unit => Ok(quote! {
215            Self::#ident => #zbus::MessageBuilder::error(call, name)?.build(&()),
216        }),
217        Fields::Unnamed(f) => {
218            // Name the unnamed fields as the number of the field with an 'f' in front.
219            let in_fields = (0..f.unnamed.len())
220                .map(|n| Ident::new(&format!("f{n}"), ident.span()).to_token_stream())
221                .collect::<Vec<_>>();
222            let out_fields = if zbus_error_variant {
223                let error_field = in_fields.first().ok_or_else(|| {
224                    Error::new(
225                        ident.span(),
226                        "expected at least one field for zbus_error variant",
227                    )
228                })?;
229                vec![quote! {
230                    match #error_field {
231                        #zbus::Error::MethodError(name, desc, _) => {
232                            ::std::clone::Clone::clone(desc)
233                        }
234                        _ => None,
235                    }
236                    .unwrap_or_else(|| ::std::string::ToString::to_string(#error_field))
237                }]
238            } else {
239                // FIXME: Workaround for https://github.com/rust-lang/rust-clippy/issues/10577
240                #[allow(clippy::redundant_clone)]
241                in_fields.clone()
242            };
243
244            Ok(quote! {
245                Self::#ident(#(#in_fields),*) => #zbus::MessageBuilder::error(call, name)?.build(&(#(#out_fields),*)),
246            })
247        }
248        Fields::Named(f) => {
249            let fields = f.named.iter().map(|v| v.ident.as_ref()).collect::<Vec<_>>();
250            Ok(quote! {
251                Self::#ident { #(#fields),* } => #zbus::MessageBuilder::error(call, name)?.build(&(#(#fields),*)),
252            })
253        }
254    }
255}