auto_enums/auto_enum/
mod.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3mod context;
4mod expr;
5#[cfg(feature = "type_analysis")]
6mod type_analysis;
7mod visitor;
8
9use proc_macro2::TokenStream;
10use quote::ToTokens as _;
11#[cfg(feature = "type_analysis")]
12use syn::Pat;
13use syn::{
14    AngleBracketedGenericArguments, Error, Expr, ExprClosure, GenericArgument, Item, ItemEnum,
15    ItemFn, Local, LocalInit, PathArguments, ReturnType, Stmt, Type, TypePath,
16};
17
18use self::{
19    context::{Context, VisitLastMode, VisitMode, DEFAULT_MARKER},
20    expr::child_expr,
21};
22use crate::utils::{block, expr_block, path_eq, replace_expr};
23
/// The name of this attribute (`auto_enum`, i.e. `#[auto_enum]`).
const NAME: &str = "auto_enum";
/// The name of the annotation that requests recursive (nested) parsing.
const NESTED: &str = "nested";
/// The name of the annotation that marks a branch to be skipped.
const NEVER: &str = "never";
30
31pub(crate) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream {
32    let mut cx = match Context::root(input.clone(), args) {
33        Ok(cx) => cx,
34        Err(e) => return e.to_compile_error(),
35    };
36
37    match syn::parse2::<Stmt>(input.clone()) {
38        Ok(mut stmt) => {
39            expand_parent_stmt(&mut cx, &mut stmt);
40            cx.check().map(|()| stmt.into_token_stream())
41        }
42        Err(e) => match syn::parse2::<Expr>(input) {
43            Err(_e) => {
44                cx.error(e);
45                cx.error(format_err!(
46                    cx.span,
47                    "may only be used on expression, statement, or function"
48                ));
49                cx.check().map(|()| unreachable!())
50            }
51            Ok(mut expr) => {
52                expand_parent_expr(&mut cx, &mut expr, false);
53                cx.check().map(|()| expr.into_token_stream())
54            }
55        },
56    }
57    .unwrap_or_else(Error::into_compile_error)
58}
59
60fn expand_expr(cx: &mut Context, expr: &mut Expr) {
61    let expr = match expr {
62        Expr::Closure(ExprClosure { body, .. }) if cx.visit_last() => {
63            let count = visitor::visit_fn(cx, &mut **body);
64            if count.try_ >= 2 {
65                cx.visit_mode = VisitMode::Try;
66            } else {
67                cx.visit_mode = VisitMode::Return(count.return_);
68            }
69            &mut **body
70        }
71        _ => expr,
72    };
73
74    child_expr(cx, expr);
75
76    #[cfg(feature = "type_analysis")]
77    {
78        if let VisitMode::Return(count) = cx.visit_mode {
79            if cx.args.is_empty() && cx.variant_is_empty() && count < 2 {
80                cx.dummy(expr);
81                return;
82            }
83        }
84    }
85
86    cx.visitor(expr);
87}
88
89fn build_expr(expr: &mut Expr, item: ItemEnum) {
90    replace_expr(expr, |expr| {
91        expr_block(block(vec![Stmt::Item(item.into()), Stmt::Expr(expr, None)]))
92    });
93}
94
95// -----------------------------------------------------------------------------
// Expand the statement, expression, `let` binding, or function on which `#[auto_enum]` was directly used.
97
98fn expand_parent_stmt(cx: &mut Context, stmt: &mut Stmt) {
99    match stmt {
100        Stmt::Expr(expr, semi) => expand_parent_expr(cx, expr, semi.is_some()),
101        Stmt::Local(local) => expand_parent_local(cx, local),
102        Stmt::Item(Item::Fn(item)) => expand_parent_item_fn(cx, item),
103        Stmt::Item(item) => {
104            cx.error(format_err!(item, "may only be used on expression, statement, or function"));
105        }
106        Stmt::Macro(_) => {}
107    }
108}
109
110fn expand_parent_expr(cx: &mut Context, expr: &mut Expr, has_semi: bool) {
111    if has_semi {
112        cx.visit_last_mode = VisitLastMode::Never;
113    }
114
115    if cx.is_dummy() {
116        cx.dummy(expr);
117        return;
118    }
119
120    expand_expr(cx, expr);
121
122    cx.build(|item| build_expr(expr, item));
123}
124
125fn expand_parent_local(cx: &mut Context, local: &mut Local) {
126    #[cfg(feature = "type_analysis")]
127    {
128        if let Pat::Type(pat) = &mut local.pat {
129            if cx.collect_impl_trait(&mut pat.ty) {
130                local.pat = (*pat.pat).clone();
131            }
132        }
133    }
134
135    if cx.is_dummy() {
136        cx.dummy(local);
137        return;
138    }
139
140    let expr = if let Some(LocalInit { expr, .. }) = &mut local.init {
141        &mut **expr
142    } else {
143        cx.error(format_err!(
144            local,
145            "the `#[auto_enum]` attribute is not supported uninitialized let statement"
146        ));
147        return;
148    };
149
150    expand_expr(cx, expr);
151
152    cx.build(|item| build_expr(expr, item));
153}
154
155fn expand_parent_item_fn(cx: &mut Context, item: &mut ItemFn) {
156    let ItemFn { sig, block, .. } = item;
157    if let ReturnType::Type(_, ty) = &mut sig.output {
158        match &**ty {
159            // `return`
160            Type::ImplTrait(_) if cx.visit_last_mode != VisitLastMode::Never => {
161                let count = visitor::visit_fn(cx, &mut **block);
162                cx.visit_mode = VisitMode::Return(count.return_);
163            }
164
165            // `?` operator
166            Type::Path(TypePath { qself: None, path })
167                if cx.visit_last_mode != VisitLastMode::Never =>
168            {
169                let ty = path.segments.last().unwrap();
170                match &ty.arguments {
171                    // `Result<T, impl Trait>`
172                    PathArguments::AngleBracketed(AngleBracketedGenericArguments {
173                        colon2_token: None,
174                        args,
175                        ..
176                    }) if args.len() == 2
177                        && path_eq(path, &["std", "core"], &["result", "Result"]) =>
178                    {
179                        if let (
180                            GenericArgument::Type(_),
181                            GenericArgument::Type(Type::ImplTrait(_)),
182                        ) = (&args[0], &args[1])
183                        {
184                            let count = visitor::visit_fn(cx, &mut **block);
185                            if count.try_ >= 2 {
186                                cx.visit_mode = VisitMode::Try;
187                            }
188                        }
189                    }
190                    _ => {}
191                }
192            }
193
194            _ => {}
195        }
196
197        #[cfg(feature = "type_analysis")]
198        cx.collect_impl_trait(&mut *ty);
199    }
200
201    if cx.is_dummy() {
202        cx.dummy(item);
203        return;
204    }
205
206    match item.block.stmts.last_mut() {
207        Some(Stmt::Expr(expr, None)) => child_expr(cx, expr),
208        Some(_) => {}
209        None => cx.error(format_err!(
210            item.block,
211            "the `#[auto_enum]` attribute is not supported empty functions"
212        )),
213    }
214
215    #[cfg(feature = "type_analysis")]
216    {
217        if let VisitMode::Return(count) = cx.visit_mode {
218            if cx.args.is_empty() && cx.variant_is_empty() && count < 2 {
219                cx.dummy(item);
220                return;
221            }
222        }
223    }
224
225    cx.visitor(item);
226
227    cx.build(|i| item.block.stmts.insert(0, Stmt::Item(i.into())));
228}