zvariant_derive/
value.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Expr, Fields, Generics, Ident,
5    Lifetime, LifetimeDef,
6};
7
8use crate::utils::*;
9
/// Selects which zvariant value type the derived conversions target.
pub enum ValueType {
    /// Generate `TryFrom<Value<..>>` for the type and `From<Type>` for the
    /// lifetime-parameterized `Value` type.
    Value,
    /// Generate the same pair of conversions against `OwnedValue` instead.
    OwnedValue,
}
14
15pub fn expand_derive(ast: DeriveInput, value_type: ValueType) -> Result<TokenStream, Error> {
16    let zv = zvariant_path();
17
18    match &ast.data {
19        Data::Struct(ds) => match &ds.fields {
20            Fields::Named(_) | Fields::Unnamed(_) => {
21                let StructAttributes { signature, .. } = StructAttributes::parse(&ast.attrs)?;
22                let signature = signature.map(|signature| match signature.as_str() {
23                    "dict" => "a{sv}".to_string(),
24                    _ => signature,
25                });
26
27                impl_struct(
28                    value_type,
29                    ast.ident,
30                    ast.generics,
31                    &ds.fields,
32                    signature,
33                    &zv,
34                )
35            }
36            Fields::Unit => Err(Error::new(ast.span(), "Unit structures not supported")),
37        },
38        Data::Enum(data) => impl_enum(value_type, ast.ident, ast.generics, ast.attrs, data, &zv),
39        _ => Err(Error::new(
40            ast.span(),
41            "only structs and enums are supported",
42        )),
43    }
44}
45
46fn impl_struct(
47    value_type: ValueType,
48    name: Ident,
49    generics: Generics,
50    fields: &Fields,
51    signature: Option<String>,
52    zv: &TokenStream,
53) -> Result<TokenStream, Error> {
54    let statc_lifetime = LifetimeDef::new(Lifetime::new("'static", Span::call_site()));
55    let (value_type, value_lifetime) = match value_type {
56        ValueType::Value => {
57            let mut lifetimes = generics.lifetimes();
58            let value_lifetime = lifetimes
59                .next()
60                .cloned()
61                .unwrap_or_else(|| statc_lifetime.clone());
62            if lifetimes.next().is_some() {
63                return Err(Error::new(
64                    name.span(),
65                    "Type with more than 1 lifetime not supported",
66                ));
67            }
68
69            (quote! { #zv::Value<#value_lifetime> }, value_lifetime)
70        }
71        ValueType::OwnedValue => (quote! { #zv::OwnedValue }, statc_lifetime),
72    };
73
74    let type_params = generics.type_params().cloned().collect::<Vec<_>>();
75    let (from_value_where_clause, into_value_where_clause) = if !type_params.is_empty() {
76        (
77            Some(quote! {
78                where
79                #(
80                    #type_params: ::std::convert::TryFrom<#zv::Value<#value_lifetime>> + #zv::Type
81                ),*
82            }),
83            Some(quote! {
84                where
85                #(
86                    #type_params: ::std::convert::Into<#zv::Value<#value_lifetime>> + #zv::Type
87                ),*
88            }),
89        )
90    } else {
91        (None, None)
92    };
93    let (impl_generics, ty_generics, _) = generics.split_for_impl();
94    match fields {
95        Fields::Named(_) => {
96            let field_names: Vec<_> = fields
97                .iter()
98                .map(|field| field.ident.to_token_stream())
99                .collect();
100            let (from_value_impl, into_value_impl) = match signature {
101                Some(signature) if signature == "a{sv}" => (
102                    // User wants the type to be encoded as a dict.
103                    // FIXME: Not the most efficient implementation.
104                    quote! {
105                        let mut fields = <::std::collections::HashMap::<::std::string::String, #zv::Value>>::try_from(value)?;
106
107                        ::std::result::Result::Ok(Self {
108                            #(
109                                #field_names:
110                                    fields
111                                        .remove(stringify!(#field_names))
112                                        .ok_or_else(|| #zv::Error::IncorrectType)?
113                                        .downcast()
114                                        .ok_or_else(|| #zv::Error::IncorrectType)?
115                            ),*
116                        })
117                    },
118                    quote! {
119                        let mut fields = ::std::collections::HashMap::new();
120                        #(
121                            fields.insert(stringify!(#field_names), #zv::Value::from(s.#field_names));
122                        )*
123
124                        #zv::Value::from(fields).into()
125                    },
126                ),
127                Some(_) | None => (
128                    quote! {
129                        let mut fields = #zv::Structure::try_from(value)?.into_fields();
130
131                        ::std::result::Result::Ok(Self {
132                            #(
133                                #field_names:
134                                    fields
135                                        .remove(0)
136                                        .downcast()
137                                        .ok_or_else(|| #zv::Error::IncorrectType)?
138                            ),*
139                        })
140                    },
141                    quote! {
142                        #zv::StructureBuilder::new()
143                        #(
144                            .add_field(s.#field_names)
145                        )*
146                        .build()
147                        .into()
148                    },
149                ),
150            };
151            Ok(quote! {
152                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
153                    #from_value_where_clause
154                {
155                    type Error = #zv::Error;
156
157                    #[inline]
158                    fn try_from(value: #value_type) -> #zv::Result<Self> {
159                        #from_value_impl
160                    }
161                }
162
163                impl #impl_generics From<#name #ty_generics> for #value_type
164                    #into_value_where_clause
165                {
166                    #[inline]
167                    fn from(s: #name #ty_generics) -> Self {
168                        #into_value_impl
169                    }
170                }
171            })
172        }
173        Fields::Unnamed(_) if fields.iter().next().is_some() => {
174            // Newtype struct.
175            Ok(quote! {
176                impl #impl_generics ::std::convert::TryFrom<#value_type> for #name #ty_generics
177                    #from_value_where_clause
178                {
179                    type Error = #zv::Error;
180
181                    #[inline]
182                    fn try_from(value: #value_type) -> #zv::Result<Self> {
183                        ::std::convert::TryInto::try_into(value).map(Self)
184                    }
185                }
186
187                impl #impl_generics From<#name #ty_generics> for #value_type
188                    #into_value_where_clause
189                {
190                    #[inline]
191                    fn from(s: #name #ty_generics) -> Self {
192                        s.0.into()
193                    }
194                }
195            })
196        }
197        Fields::Unnamed(_) => panic!("impl_struct must not be called for tuples"),
198        Fields::Unit => panic!("impl_struct must not be called for unit structures"),
199    }
200}
201
202fn impl_enum(
203    value_type: ValueType,
204    name: Ident,
205    _generics: Generics,
206    attrs: Vec<Attribute>,
207    data: &DataEnum,
208    zv: &TokenStream,
209) -> Result<TokenStream, Error> {
210    let repr: TokenStream = match attrs.iter().find(|attr| attr.path.is_ident("repr")) {
211        Some(repr_attr) => repr_attr.parse_args()?,
212        None => quote! { u32 },
213    };
214
215    let mut variant_names = vec![];
216    let mut variant_values = vec![];
217    for variant in &data.variants {
218        // Ensure all variants of the enum are unit type
219        match variant.fields {
220            Fields::Unit => {
221                variant_names.push(&variant.ident);
222                let value = match &variant
223                    .discriminant
224                    .as_ref()
225                    .ok_or_else(|| Error::new(variant.span(), "expected `Name = Value` variants"))?
226                    .1
227                {
228                    Expr::Lit(lit_exp) => &lit_exp.lit,
229                    _ => {
230                        return Err(Error::new(
231                            variant.span(),
232                            "expected `Name = Value` variants",
233                        ))
234                    }
235                };
236                variant_values.push(value);
237            }
238            _ => return Err(Error::new(variant.span(), "must be a unit variant")),
239        }
240    }
241
242    let value_type = match value_type {
243        ValueType::Value => quote! { #zv::Value<'_> },
244        ValueType::OwnedValue => quote! { #zv::OwnedValue },
245    };
246
247    Ok(quote! {
248        impl ::std::convert::TryFrom<#value_type> for #name {
249            type Error = #zv::Error;
250
251            #[inline]
252            fn try_from(value: #value_type) -> #zv::Result<Self> {
253                let v: #repr = ::std::convert::TryInto::try_into(value)?;
254
255                ::std::result::Result::Ok(match v {
256                    #(
257                        #variant_values => #name::#variant_names
258                     ),*,
259                    _ => return ::std::result::Result::Err(#zv::Error::IncorrectType),
260                })
261            }
262        }
263
264        impl ::std::convert::From<#name> for #value_type {
265            #[inline]
266            fn from(e: #name) -> Self {
267                let u: #repr = match e {
268                    #(
269                        #name::#variant_names => #variant_values
270                     ),*
271                };
272
273                <#zv::Value as ::std::convert::From<_>>::from(u).into()
274             }
275        }
276    })
277}