zvariant_derive/
dict.rs
1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field};
4use zvariant_utils::{case, macros};
5
6use crate::utils::*;
7
8pub fn expand_type_derive(input: DeriveInput) -> Result<TokenStream, Error> {
9 let name = match input.data {
10 Data::Struct(_) => input.ident,
11 _ => return Err(Error::new(input.span(), "only structs supported")),
12 };
13
14 let zv = zvariant_path();
15 let generics = input.generics;
16 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
17
18 Ok(quote! {
19 impl #impl_generics #zv::Type for #name #ty_generics
20 #where_clause
21 {
22 fn signature() -> #zv::Signature<'static> {
23 #zv::Signature::from_static_str_unchecked("a{sv}")
24 }
25 }
26 })
27}
28
29fn dict_name_for_field(
30 f: &Field,
31 rename_attr: Option<String>,
32 rename_all_attr: Option<&str>,
33) -> Result<String, Error> {
34 if let Some(name) = rename_attr {
35 Ok(name)
36 } else {
37 let ident = f.ident.as_ref().unwrap().to_string();
38
39 match rename_all_attr {
40 Some("lowercase") => Ok(ident.to_ascii_lowercase()),
41 Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()),
42 Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, true)),
43 Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, false)),
44 Some("snake_case") => Ok(case::snake_case(&ident)),
45 None => Ok(ident),
46 Some(other) => Err(Error::new(
47 f.span(),
48 format!("invalid `rename_all` attribute value {other}"),
49 )),
50 }
51 }
52}
53
54pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
55 let (name, data) = match input.data {
56 Data::Struct(data) => (input.ident, data),
57 _ => return Err(Error::new(input.span(), "only structs supported")),
58 };
59
60 let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?;
61
62 let zv = zvariant_path();
63 let mut entries = quote! {};
64 let mut num_entries: usize = 0;
65
66 for f in &data.fields {
67 let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
68
69 let name = &f.ident;
70 let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
71
72 let is_option = macros::ty_is_option(&f.ty);
73
74 let e = if is_option {
75 quote! {
76 if self.#name.is_some() {
77 map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?;
78 }
79 }
80 } else {
81 quote! {
82 map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?;
83 }
84 };
85
86 entries.extend(e);
87 num_entries += 1;
88 }
89
90 let generics = input.generics;
91 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
92
93 let num_entries = num_entries.to_token_stream();
94 Ok(quote! {
95 #[allow(deprecated)]
96 impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics
97 #where_clause
98 {
99 fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
100 where
101 S: #zv::export::serde::ser::Serializer,
102 {
103 use #zv::export::serde::ser::SerializeMap;
104
105 let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?;
107 #entries
108 map.end()
109 }
110 }
111 })
112}
113
114pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
115 let (name, data) = match input.data {
116 Data::Struct(data) => (input.ident, data),
117 _ => return Err(Error::new(input.span(), "only structs supported")),
118 };
119
120 let StructAttributes {
121 rename_all,
122 deny_unknown_fields,
123 ..
124 } = StructAttributes::parse(&input.attrs)?;
125
126 let visitor = format_ident!("{}Visitor", name);
127 let zv = zvariant_path();
128 let mut fields = Vec::new();
129 let mut req_fields = Vec::new();
130 let mut dict_names = Vec::new();
131 let mut entries = Vec::new();
132
133 for f in &data.fields {
134 let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
135
136 let name = &f.ident;
137 let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
138
139 let is_option = macros::ty_is_option(&f.ty);
140
141 entries.push(quote! {
142 #dict_name => {
143 #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok();
145 }
146 });
147
148 dict_names.push(dict_name);
149 fields.push(name);
150
151 if !is_option {
152 req_fields.push(name);
153 }
154 }
155
156 let fallback = if deny_unknown_fields {
157 quote! {
158 field => {
159 return ::std::result::Result::Err(
160 <M::Error as #zv::export::serde::de::Error>::unknown_field(
161 field,
162 &[#(#dict_names),*],
163 ),
164 );
165 }
166 }
167 } else {
168 quote! {
169 unknown => {
170 let _ = access.next_value::<#zv::Value>();
171 }
172 }
173 };
174 entries.push(fallback);
175
176 let (_, ty_generics, _) = input.generics.split_for_impl();
177 let mut generics = input.generics.clone();
178 let def = syn::LifetimeDef {
179 attrs: Vec::new(),
180 lifetime: syn::Lifetime::new("'de", Span::call_site()),
181 colon_token: None,
182 bounds: Punctuated::new(),
183 };
184 generics.params = Some(syn::GenericParam::Lifetime(def))
185 .into_iter()
186 .chain(generics.params)
187 .collect();
188
189 let (impl_generics, _, where_clause) = generics.split_for_impl();
190
191 Ok(quote! {
192 #[allow(deprecated)]
193 impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics
194 #where_clause
195 {
196 fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
197 where
198 D: #zv::export::serde::de::Deserializer<'de>,
199 {
200 struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>);
201
202 impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics {
203 type Value = #name #ty_generics;
204
205 fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
206 formatter.write_str("a dictionary")
207 }
208
209 fn visit_map<M>(
210 self,
211 mut access: M,
212 ) -> ::std::result::Result<Self::Value, M::Error>
213 where
214 M: #zv::export::serde::de::MapAccess<'de>,
215 {
216 #( let mut #fields = ::std::default::Default::default(); )*
217
218 while let ::std::option::Option::Some(key) = access.next_key::<&str>()? {
220 match key {
221 #(#entries)*
222 }
223 }
224
225 #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields {
226 val
227 } else {
228 return ::std::result::Result::Err(
229 <M::Error as #zv::export::serde::de::Error>::missing_field(
230 ::std::stringify!(#req_fields),
231 ),
232 );
233 };)*
234
235 ::std::result::Result::Ok(#name { #(#fields),* })
236 }
237 }
238
239
240 deserializer.deserialize_map(#visitor(::std::marker::PhantomData))
241 }
242 }
243 })
244}