enumflags2_derive/
lib.rs

1#![recursion_limit = "2048"]
2extern crate proc_macro;
3#[macro_use]
4extern crate quote;
5
6use proc_macro2::{Span, TokenStream};
7use std::convert::TryFrom;
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input,
11    spanned::Spanned,
12    Expr, Ident, DeriveInput, Data, Token, Variant,
13};
14
15struct Flag<'a> {
16    name: Ident,
17    span: Span,
18    value: FlagValue<'a>,
19}
20
21enum FlagValue<'a> {
22    Literal(u128),
23    Deferred,
24    Inferred(&'a mut Variant),
25}
26
27impl FlagValue<'_> {
28    fn is_inferred(&self) -> bool {
29        matches!(self, FlagValue::Inferred(_))
30    }
31}
32
33struct Parameters {
34    default: Vec<Ident>,
35}
36
37impl Parse for Parameters {
38    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
39        if input.is_empty() {
40            return Ok(Parameters { default: vec![] });
41        }
42
43        input.parse::<Token![default]>()?;
44        input.parse::<Token![=]>()?;
45        let mut default = vec![input.parse()?];
46        while !input.is_empty() {
47            input.parse::<Token![|]>()?;
48            default.push(input.parse()?);
49        }
50
51        Ok(Parameters { default })
52    }
53}
54
55#[proc_macro_attribute]
56pub fn bitflags_internal(
57    attr: proc_macro::TokenStream,
58    input: proc_macro::TokenStream,
59) -> proc_macro::TokenStream {
60    let Parameters { default } = parse_macro_input!(attr as Parameters);
61    let mut ast = parse_macro_input!(input as DeriveInput);
62    let output = gen_enumflags(&mut ast, default);
63
64    output
65        .unwrap_or_else(|err| {
66            let error = err.to_compile_error();
67            quote! {
68                #ast
69                #error
70            }
71        })
72        .into()
73}
74
75/// Try to evaluate the expression given.
76fn fold_expr(expr: &syn::Expr) -> Option<u128> {
77    match expr {
78        Expr::Lit(ref expr_lit) => match expr_lit.lit {
79            syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
80            _ => None,
81        },
82        Expr::Binary(ref expr_binary) => {
83            let l = fold_expr(&expr_binary.left)?;
84            let r = fold_expr(&expr_binary.right)?;
85            match &expr_binary.op {
86                syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
87                _ => None,
88            }
89        }
90        Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
91            fold_expr(expr)
92        }
93        _ => None,
94    }
95}
96
97fn collect_flags<'a>(
98    variants: impl Iterator<Item = &'a mut Variant>,
99) -> Result<Vec<Flag<'a>>, syn::Error> {
100    variants
101        .map(|variant| {
102            if !matches!(variant.fields, syn::Fields::Unit) {
103                return Err(syn::Error::new_spanned(
104                    &variant.fields,
105                    "Bitflag variants cannot contain additional data",
106                ));
107            }
108
109            let name = variant.ident.clone();
110            let span = variant.span();
111            let value = if let Some(ref expr) = variant.discriminant {
112                if let Some(n) = fold_expr(&expr.1) {
113                    FlagValue::Literal(n)
114                } else {
115                    FlagValue::Deferred
116                }
117            } else {
118                FlagValue::Inferred(variant)
119            };
120
121            Ok(Flag { name, span, value })
122        })
123        .collect()
124}
125
126fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
127    let tokens = if previous_variants.is_empty() {
128        quote!(1)
129    } else {
130        quote!(::enumflags2::_internal::next_bit(
131                #(#type_name::#previous_variants as u128)|*
132        ) as #repr)
133    };
134
135    syn::parse2(tokens).expect("couldn't parse inferred value")
136}
137
138fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
139    let mut previous_variants: Vec<Ident> = flags
140        .iter()
141        .filter(|flag| !flag.value.is_inferred())
142        .map(|flag| flag.name.clone())
143        .collect();
144
145    for flag in flags {
146        if let FlagValue::Inferred(ref mut variant) = flag.value {
147            variant.discriminant = Some((
148                <Token![=]>::default(),
149                inferred_value(type_name, &previous_variants, repr),
150            ));
151            previous_variants.push(flag.name.clone());
152        }
153    }
154}
155
156/// Given a list of attributes, find the `repr`, if any, and return the integer
157/// type specified.
158fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
159    let mut res = None;
160    for attr in attrs {
161        if attr.path().is_ident("repr") {
162            attr.parse_nested_meta(|meta| {
163                if let Some(ident) = meta.path.get_ident() {
164                    res = Some(ident.clone());
165                }
166                Ok(())
167            })?;
168        }
169    }
170    Ok(res)
171}
172
173/// Check the repr and return the number of bits available
174fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
175    // This would be so much easier if we could just match on an Ident...
176    if ty == "usize" {
177        Err(syn::Error::new_spanned(
178            ty,
179            "#[repr(usize)] is not supported. Use u32 or u64 instead.",
180        ))
181    } else if ty == "i8"
182        || ty == "i16"
183        || ty == "i32"
184        || ty == "i64"
185        || ty == "i128"
186        || ty == "isize"
187    {
188        Err(syn::Error::new_spanned(
189            ty,
190            "Signed types in a repr are not supported.",
191        ))
192    } else if ty == "u8" {
193        Ok(8)
194    } else if ty == "u16" {
195        Ok(16)
196    } else if ty == "u32" {
197        Ok(32)
198    } else if ty == "u64" {
199        Ok(64)
200    } else if ty == "u128" {
201        Ok(128)
202    } else {
203        Err(syn::Error::new_spanned(
204            ty,
205            "repr must be an integer type for #[bitflags].",
206        ))
207    }
208}
209
210/// Returns deferred checks
211fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
212    use FlagValue::*;
213    match flag.value {
214        Literal(n) => {
215            if !n.is_power_of_two() {
216                Err(syn::Error::new(
217                    flag.span,
218                    "Flags must have exactly one set bit",
219                ))
220            } else if bits < 128 && n >= 1 << bits {
221                Err(syn::Error::new(
222                    flag.span,
223                    format!("Flag value out of range for u{}", bits),
224                ))
225            } else {
226                Ok(None)
227            }
228        }
229        Inferred(_) => Ok(None),
230        Deferred => {
231            let variant_name = &flag.name;
232            Ok(Some(quote_spanned!(flag.span =>
233                const _:
234                    <<[(); (
235                        (#type_name::#variant_name as u128).is_power_of_two()
236                    ) as usize] as ::enumflags2::_internal::AssertionHelper>
237                        ::Status as ::enumflags2::_internal::ExactlyOneBitSet>::X
238                    = ();
239            )))
240        }
241    }
242}
243
244fn gen_enumflags(ast: &mut DeriveInput, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
245    let ident = &ast.ident;
246
247    let span = Span::call_site();
248
249    let ast_variants = match &mut ast.data {
250        Data::Enum(ref mut data) => &mut data.variants,
251        Data::Struct(data) => {
252            return Err(syn::Error::new_spanned(&data.struct_token,
253                "expected enum for #[bitflags], found struct"));
254        }
255        Data::Union(data) => {
256            return Err(syn::Error::new_spanned(&data.union_token,
257                "expected enum for #[bitflags], found union"));
258        }
259    };
260
261    if ast.generics.lt_token.is_some() || ast.generics.where_clause.is_some() {
262        return Err(syn::Error::new_spanned(&ast.generics,
263            "bitflags cannot be generic"));
264    }
265
266    let repr = extract_repr(&ast.attrs)?
267        .ok_or_else(|| syn::Error::new_spanned(ident,
268                        "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
269    let bits = type_bits(&repr)?;
270
271    let mut variants = collect_flags(ast_variants.iter_mut())?;
272    let deferred = variants
273        .iter()
274        .flat_map(|variant| check_flag(ident, variant, bits).transpose())
275        .collect::<Result<Vec<_>, _>>()?;
276
277    infer_values(&mut variants, ident, &repr);
278
279    if (bits as usize) < variants.len() {
280        return Err(syn::Error::new_spanned(
281            &repr,
282            format!("Not enough bits for {} flags", variants.len()),
283        ));
284    }
285
286    let std = quote_spanned!(span => ::enumflags2::_internal::core);
287    let ast_variants = match &ast.data {
288        Data::Enum(ref data) => &data.variants,
289        _ => unreachable!(),
290    };
291
292    let variant_names = ast_variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
293
294    Ok(quote_spanned! {
295        span =>
296            #ast
297            #(#deferred)*
298            impl #std::ops::Not for #ident {
299                type Output = ::enumflags2::BitFlags<Self>;
300                #[inline(always)]
301                fn not(self) -> Self::Output {
302                    use ::enumflags2::BitFlags;
303                    BitFlags::from_flag(self).not()
304                }
305            }
306
307            impl #std::ops::BitOr for #ident {
308                type Output = ::enumflags2::BitFlags<Self>;
309                #[inline(always)]
310                fn bitor(self, other: Self) -> Self::Output {
311                    use ::enumflags2::BitFlags;
312                    BitFlags::from_flag(self) | other
313                }
314            }
315
316            impl #std::ops::BitAnd for #ident {
317                type Output = ::enumflags2::BitFlags<Self>;
318                #[inline(always)]
319                fn bitand(self, other: Self) -> Self::Output {
320                    use ::enumflags2::BitFlags;
321                    BitFlags::from_flag(self) & other
322                }
323            }
324
325            impl #std::ops::BitXor for #ident {
326                type Output = ::enumflags2::BitFlags<Self>;
327                #[inline(always)]
328                fn bitxor(self, other: Self) -> Self::Output {
329                    use ::enumflags2::BitFlags;
330                    BitFlags::from_flag(self) ^ other
331                }
332            }
333
334            unsafe impl ::enumflags2::_internal::RawBitFlags for #ident {
335                type Numeric = #repr;
336
337                const EMPTY: Self::Numeric = 0;
338
339                const DEFAULT: Self::Numeric =
340                    0 #(| (Self::#default as #repr))*;
341
342                const ALL_BITS: Self::Numeric =
343                    0 #(| (Self::#variant_names as #repr))*;
344
345                const BITFLAGS_TYPE_NAME : &'static str =
346                    concat!("BitFlags<", stringify!(#ident), ">");
347
348                fn bits(self) -> Self::Numeric {
349                    self as #repr
350                }
351            }
352
353            impl ::enumflags2::BitFlag for #ident {}
354    })
355}