1#![recursion_limit = "2048"]
2extern crate proc_macro;
3#[macro_use]
4extern crate quote;
5
6use proc_macro2::{Span, TokenStream};
7use std::convert::TryFrom;
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 spanned::Spanned,
12 Expr, Ident, DeriveInput, Data, Token, Variant,
13};
14
15struct Flag<'a> {
16 name: Ident,
17 span: Span,
18 value: FlagValue<'a>,
19}
20
21enum FlagValue<'a> {
22 Literal(u128),
23 Deferred,
24 Inferred(&'a mut Variant),
25}
26
27impl FlagValue<'_> {
28 fn is_inferred(&self) -> bool {
29 matches!(self, FlagValue::Inferred(_))
30 }
31}
32
33struct Parameters {
34 default: Vec<Ident>,
35}
36
37impl Parse for Parameters {
38 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
39 if input.is_empty() {
40 return Ok(Parameters { default: vec![] });
41 }
42
43 input.parse::<Token![default]>()?;
44 input.parse::<Token![=]>()?;
45 let mut default = vec![input.parse()?];
46 while !input.is_empty() {
47 input.parse::<Token![|]>()?;
48 default.push(input.parse()?);
49 }
50
51 Ok(Parameters { default })
52 }
53}
54
55#[proc_macro_attribute]
56pub fn bitflags_internal(
57 attr: proc_macro::TokenStream,
58 input: proc_macro::TokenStream,
59) -> proc_macro::TokenStream {
60 let Parameters { default } = parse_macro_input!(attr as Parameters);
61 let mut ast = parse_macro_input!(input as DeriveInput);
62 let output = gen_enumflags(&mut ast, default);
63
64 output
65 .unwrap_or_else(|err| {
66 let error = err.to_compile_error();
67 quote! {
68 #ast
69 #error
70 }
71 })
72 .into()
73}
74
75fn fold_expr(expr: &syn::Expr) -> Option<u128> {
77 match expr {
78 Expr::Lit(ref expr_lit) => match expr_lit.lit {
79 syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
80 _ => None,
81 },
82 Expr::Binary(ref expr_binary) => {
83 let l = fold_expr(&expr_binary.left)?;
84 let r = fold_expr(&expr_binary.right)?;
85 match &expr_binary.op {
86 syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r| l.checked_shl(r)),
87 _ => None,
88 }
89 }
90 Expr::Paren(syn::ExprParen { expr, .. }) | Expr::Group(syn::ExprGroup { expr, .. }) => {
91 fold_expr(expr)
92 }
93 _ => None,
94 }
95}
96
97fn collect_flags<'a>(
98 variants: impl Iterator<Item = &'a mut Variant>,
99) -> Result<Vec<Flag<'a>>, syn::Error> {
100 variants
101 .map(|variant| {
102 if !matches!(variant.fields, syn::Fields::Unit) {
103 return Err(syn::Error::new_spanned(
104 &variant.fields,
105 "Bitflag variants cannot contain additional data",
106 ));
107 }
108
109 let name = variant.ident.clone();
110 let span = variant.span();
111 let value = if let Some(ref expr) = variant.discriminant {
112 if let Some(n) = fold_expr(&expr.1) {
113 FlagValue::Literal(n)
114 } else {
115 FlagValue::Deferred
116 }
117 } else {
118 FlagValue::Inferred(variant)
119 };
120
121 Ok(Flag { name, span, value })
122 })
123 .collect()
124}
125
126fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
127 let tokens = if previous_variants.is_empty() {
128 quote!(1)
129 } else {
130 quote!(::enumflags2::_internal::next_bit(
131 #(#type_name::#previous_variants as u128)|*
132 ) as #repr)
133 };
134
135 syn::parse2(tokens).expect("couldn't parse inferred value")
136}
137
138fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
139 let mut previous_variants: Vec<Ident> = flags
140 .iter()
141 .filter(|flag| !flag.value.is_inferred())
142 .map(|flag| flag.name.clone())
143 .collect();
144
145 for flag in flags {
146 if let FlagValue::Inferred(ref mut variant) = flag.value {
147 variant.discriminant = Some((
148 <Token![=]>::default(),
149 inferred_value(type_name, &previous_variants, repr),
150 ));
151 previous_variants.push(flag.name.clone());
152 }
153 }
154}
155
156fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
159 let mut res = None;
160 for attr in attrs {
161 if attr.path().is_ident("repr") {
162 attr.parse_nested_meta(|meta| {
163 if let Some(ident) = meta.path.get_ident() {
164 res = Some(ident.clone());
165 }
166 Ok(())
167 })?;
168 }
169 }
170 Ok(res)
171}
172
173fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
175 if ty == "usize" {
177 Err(syn::Error::new_spanned(
178 ty,
179 "#[repr(usize)] is not supported. Use u32 or u64 instead.",
180 ))
181 } else if ty == "i8"
182 || ty == "i16"
183 || ty == "i32"
184 || ty == "i64"
185 || ty == "i128"
186 || ty == "isize"
187 {
188 Err(syn::Error::new_spanned(
189 ty,
190 "Signed types in a repr are not supported.",
191 ))
192 } else if ty == "u8" {
193 Ok(8)
194 } else if ty == "u16" {
195 Ok(16)
196 } else if ty == "u32" {
197 Ok(32)
198 } else if ty == "u64" {
199 Ok(64)
200 } else if ty == "u128" {
201 Ok(128)
202 } else {
203 Err(syn::Error::new_spanned(
204 ty,
205 "repr must be an integer type for #[bitflags].",
206 ))
207 }
208}
209
210fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
212 use FlagValue::*;
213 match flag.value {
214 Literal(n) => {
215 if !n.is_power_of_two() {
216 Err(syn::Error::new(
217 flag.span,
218 "Flags must have exactly one set bit",
219 ))
220 } else if bits < 128 && n >= 1 << bits {
221 Err(syn::Error::new(
222 flag.span,
223 format!("Flag value out of range for u{}", bits),
224 ))
225 } else {
226 Ok(None)
227 }
228 }
229 Inferred(_) => Ok(None),
230 Deferred => {
231 let variant_name = &flag.name;
232 Ok(Some(quote_spanned!(flag.span =>
233 const _:
234 <<[(); (
235 (#type_name::#variant_name as u128).is_power_of_two()
236 ) as usize] as ::enumflags2::_internal::AssertionHelper>
237 ::Status as ::enumflags2::_internal::ExactlyOneBitSet>::X
238 = ();
239 )))
240 }
241 }
242}
243
244fn gen_enumflags(ast: &mut DeriveInput, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
245 let ident = &ast.ident;
246
247 let span = Span::call_site();
248
249 let ast_variants = match &mut ast.data {
250 Data::Enum(ref mut data) => &mut data.variants,
251 Data::Struct(data) => {
252 return Err(syn::Error::new_spanned(&data.struct_token,
253 "expected enum for #[bitflags], found struct"));
254 }
255 Data::Union(data) => {
256 return Err(syn::Error::new_spanned(&data.union_token,
257 "expected enum for #[bitflags], found union"));
258 }
259 };
260
261 if ast.generics.lt_token.is_some() || ast.generics.where_clause.is_some() {
262 return Err(syn::Error::new_spanned(&ast.generics,
263 "bitflags cannot be generic"));
264 }
265
266 let repr = extract_repr(&ast.attrs)?
267 .ok_or_else(|| syn::Error::new_spanned(ident,
268 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
269 let bits = type_bits(&repr)?;
270
271 let mut variants = collect_flags(ast_variants.iter_mut())?;
272 let deferred = variants
273 .iter()
274 .flat_map(|variant| check_flag(ident, variant, bits).transpose())
275 .collect::<Result<Vec<_>, _>>()?;
276
277 infer_values(&mut variants, ident, &repr);
278
279 if (bits as usize) < variants.len() {
280 return Err(syn::Error::new_spanned(
281 &repr,
282 format!("Not enough bits for {} flags", variants.len()),
283 ));
284 }
285
286 let std = quote_spanned!(span => ::enumflags2::_internal::core);
287 let ast_variants = match &ast.data {
288 Data::Enum(ref data) => &data.variants,
289 _ => unreachable!(),
290 };
291
292 let variant_names = ast_variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
293
294 Ok(quote_spanned! {
295 span =>
296 #ast
297 #(#deferred)*
298 impl #std::ops::Not for #ident {
299 type Output = ::enumflags2::BitFlags<Self>;
300 #[inline(always)]
301 fn not(self) -> Self::Output {
302 use ::enumflags2::BitFlags;
303 BitFlags::from_flag(self).not()
304 }
305 }
306
307 impl #std::ops::BitOr for #ident {
308 type Output = ::enumflags2::BitFlags<Self>;
309 #[inline(always)]
310 fn bitor(self, other: Self) -> Self::Output {
311 use ::enumflags2::BitFlags;
312 BitFlags::from_flag(self) | other
313 }
314 }
315
316 impl #std::ops::BitAnd for #ident {
317 type Output = ::enumflags2::BitFlags<Self>;
318 #[inline(always)]
319 fn bitand(self, other: Self) -> Self::Output {
320 use ::enumflags2::BitFlags;
321 BitFlags::from_flag(self) & other
322 }
323 }
324
325 impl #std::ops::BitXor for #ident {
326 type Output = ::enumflags2::BitFlags<Self>;
327 #[inline(always)]
328 fn bitxor(self, other: Self) -> Self::Output {
329 use ::enumflags2::BitFlags;
330 BitFlags::from_flag(self) ^ other
331 }
332 }
333
334 unsafe impl ::enumflags2::_internal::RawBitFlags for #ident {
335 type Numeric = #repr;
336
337 const EMPTY: Self::Numeric = 0;
338
339 const DEFAULT: Self::Numeric =
340 0 #(| (Self::#default as #repr))*;
341
342 const ALL_BITS: Self::Numeric =
343 0 #(| (Self::#variant_names as #repr))*;
344
345 const BITFLAGS_TYPE_NAME : &'static str =
346 concat!("BitFlags<", stringify!(#ident), ">");
347
348 fn bits(self) -> Self::Numeric {
349 self as #repr
350 }
351 }
352
353 impl ::enumflags2::BitFlag for #ident {}
354 })
355}