zvariant_derive/
dict.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field};
4use zvariant_utils::{case, macros};
5
6use crate::utils::*;
7
8pub fn expand_type_derive(input: DeriveInput) -> Result<TokenStream, Error> {
9    let name = match input.data {
10        Data::Struct(_) => input.ident,
11        _ => return Err(Error::new(input.span(), "only structs supported")),
12    };
13
14    let zv = zvariant_path();
15    let generics = input.generics;
16    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
17
18    Ok(quote! {
19        impl #impl_generics #zv::Type for #name #ty_generics
20        #where_clause
21        {
22            fn signature() -> #zv::Signature<'static> {
23                #zv::Signature::from_static_str_unchecked("a{sv}")
24            }
25        }
26    })
27}
28
29fn dict_name_for_field(
30    f: &Field,
31    rename_attr: Option<String>,
32    rename_all_attr: Option<&str>,
33) -> Result<String, Error> {
34    if let Some(name) = rename_attr {
35        Ok(name)
36    } else {
37        let ident = f.ident.as_ref().unwrap().to_string();
38
39        match rename_all_attr {
40            Some("lowercase") => Ok(ident.to_ascii_lowercase()),
41            Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()),
42            Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, true)),
43            Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, false)),
44            Some("snake_case") => Ok(case::snake_case(&ident)),
45            None => Ok(ident),
46            Some(other) => Err(Error::new(
47                f.span(),
48                format!("invalid `rename_all` attribute value {other}"),
49            )),
50        }
51    }
52}
53
54pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
55    let (name, data) = match input.data {
56        Data::Struct(data) => (input.ident, data),
57        _ => return Err(Error::new(input.span(), "only structs supported")),
58    };
59
60    let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?;
61
62    let zv = zvariant_path();
63    let mut entries = quote! {};
64    let mut num_entries: usize = 0;
65
66    for f in &data.fields {
67        let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
68
69        let name = &f.ident;
70        let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
71
72        let is_option = macros::ty_is_option(&f.ty);
73
74        let e = if is_option {
75            quote! {
76                if self.#name.is_some() {
77                    map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?;
78                }
79            }
80        } else {
81            quote! {
82                map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?;
83            }
84        };
85
86        entries.extend(e);
87        num_entries += 1;
88    }
89
90    let generics = input.generics;
91    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
92
93    let num_entries = num_entries.to_token_stream();
94    Ok(quote! {
95        #[allow(deprecated)]
96        impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics
97        #where_clause
98        {
99            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
100            where
101                S: #zv::export::serde::ser::Serializer,
102            {
103                use #zv::export::serde::ser::SerializeMap;
104
105                // zbus doesn't care about number of entries (it would need bytes instead)
106                let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?;
107                #entries
108                map.end()
109            }
110        }
111    })
112}
113
114pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
115    let (name, data) = match input.data {
116        Data::Struct(data) => (input.ident, data),
117        _ => return Err(Error::new(input.span(), "only structs supported")),
118    };
119
120    let StructAttributes {
121        rename_all,
122        deny_unknown_fields,
123        ..
124    } = StructAttributes::parse(&input.attrs)?;
125
126    let visitor = format_ident!("{}Visitor", name);
127    let zv = zvariant_path();
128    let mut fields = Vec::new();
129    let mut req_fields = Vec::new();
130    let mut dict_names = Vec::new();
131    let mut entries = Vec::new();
132
133    for f in &data.fields {
134        let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
135
136        let name = &f.ident;
137        let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
138
139        let is_option = macros::ty_is_option(&f.ty);
140
141        entries.push(quote! {
142            #dict_name => {
143                // FIXME: add an option about strict parsing (instead of silently skipping the field)
144                #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok();
145            }
146        });
147
148        dict_names.push(dict_name);
149        fields.push(name);
150
151        if !is_option {
152            req_fields.push(name);
153        }
154    }
155
156    let fallback = if deny_unknown_fields {
157        quote! {
158            field => {
159                return ::std::result::Result::Err(
160                    <M::Error as #zv::export::serde::de::Error>::unknown_field(
161                        field,
162                        &[#(#dict_names),*],
163                    ),
164                );
165            }
166        }
167    } else {
168        quote! {
169            unknown => {
170                let _ = access.next_value::<#zv::Value>();
171            }
172        }
173    };
174    entries.push(fallback);
175
176    let (_, ty_generics, _) = input.generics.split_for_impl();
177    let mut generics = input.generics.clone();
178    let def = syn::LifetimeDef {
179        attrs: Vec::new(),
180        lifetime: syn::Lifetime::new("'de", Span::call_site()),
181        colon_token: None,
182        bounds: Punctuated::new(),
183    };
184    generics.params = Some(syn::GenericParam::Lifetime(def))
185        .into_iter()
186        .chain(generics.params)
187        .collect();
188
189    let (impl_generics, _, where_clause) = generics.split_for_impl();
190
191    Ok(quote! {
192        #[allow(deprecated)]
193        impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics
194        #where_clause
195        {
196            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
197            where
198                D: #zv::export::serde::de::Deserializer<'de>,
199            {
200                struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>);
201
202                impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics {
203                    type Value = #name #ty_generics;
204
205                    fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
206                        formatter.write_str("a dictionary")
207                    }
208
209                    fn visit_map<M>(
210                        self,
211                        mut access: M,
212                    ) -> ::std::result::Result<Self::Value, M::Error>
213                    where
214                        M: #zv::export::serde::de::MapAccess<'de>,
215                    {
216                        #( let mut #fields = ::std::default::Default::default(); )*
217
218                        // does not check duplicated fields, since those shouldn't exist in stream
219                        while let ::std::option::Option::Some(key) = access.next_key::<&str>()? {
220                            match key {
221                                #(#entries)*
222                            }
223                        }
224
225                        #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields {
226                            val
227                        } else {
228                            return ::std::result::Result::Err(
229                                <M::Error as #zv::export::serde::de::Error>::missing_field(
230                                    ::std::stringify!(#req_fields),
231                                ),
232                            );
233                        };)*
234
235                        ::std::result::Result::Ok(#name { #(#fields),* })
236                    }
237                }
238
239
240                deserializer.deserialize_map(#visitor(::std::marker::PhantomData))
241            }
242        }
243    })
244}