auto_enums/auto_enum/
visitor.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::TokenStream;
4use quote::ToTokens as _;
5use syn::{
6    parse_quote, token,
7    visit_mut::{self, VisitMut},
8    Arm, Attribute, Expr, ExprMacro, ExprMatch, ExprReturn, ExprTry, Item, Local, LocalInit,
9    MetaList, Stmt, Token,
10};
11
12use super::{Context, VisitMode, DEFAULT_MARKER, NAME, NESTED, NEVER};
13use crate::utils::{replace_expr, Attrs, Node};
14
15#[derive(Clone, Copy, Default)]
16struct Scope {
17    /// in closures
18    closure: bool,
19    /// in try blocks
20    try_block: bool,
21    /// in the other `auto_enum` attributes
22    foreign: bool,
23}
24
25impl Scope {
26    // check this scope is in closures or try blocks.
27    fn check_expr(&mut self, expr: &Expr) {
28        match expr {
29            Expr::Closure(_) => self.closure = true,
30            // `?` operator in try blocks are not supported.
31            Expr::TryBlock(_) => self.try_block = true,
32            _ => {}
33        }
34    }
35}
36
37// -----------------------------------------------------------------------------
38// default visitor
39
40pub(super) struct Visitor<'a> {
41    cx: &'a mut Context,
42    scope: Scope,
43}
44
impl<'a> Visitor<'a> {
    /// Creates a visitor over `cx`, starting outside any closure, try block, or
    /// nested `#[auto_enum]` (i.e. with `Scope::default()`).
    pub(super) fn new(cx: &'a mut Context) -> Self {
        Self { cx, scope: Scope::default() }
    }

    /// Removes the `NEVER` marker attribute and the obsolete `#[rec]` attribute from
    /// `attrs`, reporting a `NEVER` attribute with arguments, or any `#[rec]`, as an error.
    ///
    /// Does nothing while inside a nested `#[auto_enum]` (presumably those attributes are
    /// left for the nested attribute's own expansion — TODO confirm).
    fn find_remove_attrs(&self, attrs: &mut impl Attrs) {
        if !self.scope.foreign {
            // `NEVER` must be written as a bare path, without arguments.
            if let Some(attr) = attrs.find_remove_attr(NEVER) {
                if let Err(e) = attr.meta.require_path_only() {
                    self.cx.error(e);
                }
            }

            // The old annotation `#[rec]` is replaced with `#[nested]`.
            if let Some(old) = attrs.find_remove_attr("rec") {
                self.cx.error(format_err!(
                    old,
                    "#[rec] has been removed and replaced with #[{}]",
                    NESTED
                ));
            }
        }
    }

    /// Handles a `return` expression while the context is in `VisitMode::Return` mode.
    ///
    /// Only `return`s that leave the function itself are rewritten. The following are
    /// left untouched:
    /// - `return`s inside closures (`Scope::closure`);
    /// - `return`s carrying the `NEVER` attribute;
    /// - `return`s whose operand is already a marker macro.
    ///
    /// `count` is only used by the debug assertion.
    fn visit_return(&mut self, node: &mut Expr, count: usize) {
        debug_assert!(self.cx.visit_mode == VisitMode::Return(count));

        if !self.scope.closure && !node.any_empty_attr(NEVER) {
            // Desugar `return <expr>` into `return Enum::VariantN(<expr>)`.
            if let Expr::Return(ExprReturn { expr, .. }) = node {
                // Skip if `<expr>` is a marker macro.
                if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) {
                    self.cx.replace_boxed_expr(expr);
                }
            }
        }
    }

    /// Handles the `?` operator while the context is in `VisitMode::Try` mode.
    ///
    /// `<expr>?` is rebuilt as a `match` whose `Err` arm returns the error wrapped in
    /// the next enum variant. Skipped in these cases:
    /// - inside try blocks or closures;
    /// - on expressions carrying the `NEVER` attribute;
    /// - when `<expr>` is a marker macro.
    fn visit_try(&mut self, node: &mut Expr) {
        debug_assert!(self.cx.visit_mode == VisitMode::Try);

        if !self.scope.try_block && !self.scope.closure && !node.any_empty_attr(NEVER) {
            match &node {
                // https://github.com/rust-lang/rust/blob/1.35.0/src/librustc/hir/lowering.rs#L4578-L4682

                // Desugar `<expr>?`
                // into:
                //
                // match <expr> {
                //     Ok(val) => val,
                //     Err(err) => return Err(Enum::VariantN(err)),
                // }
                //
                // Skip if `<expr>` is a marker macro.
                Expr::Try(ExprTry { expr, .. }) if !self.cx.is_marker_expr(expr) => {
                    replace_expr(node, |expr| {
                        // The guard above guarantees this is the `?` expression.
                        let ExprTry { attrs, expr, .. } =
                            if let Expr::Try(expr) = expr { expr } else { unreachable!() };

                        // `err` wrapped in the next enum variant.
                        let err = self.cx.next_expr(parse_quote!(err));
                        let arms = vec![
                            parse_quote! {
                                ::core::result::Result::Ok(val) => val,
                            },
                            parse_quote! {
                                ::core::result::Result::Err(err) => {
                                    return ::core::result::Result::Err(#err);
                                }
                            },
                        ];

                        // The `match` keeps the attributes and operand of the original `?`.
                        Expr::Match(ExprMatch {
                            attrs,
                            match_token: <Token![match]>::default(),
                            expr,
                            brace_token: token::Brace::default(),
                            arms,
                        })
                    });
                }
                _ => {}
            }
        }
    }

    /// Handles a `#[nested]` attribute that has already been removed from `node`.
    ///
    /// The attribute must be written without arguments. If it is, `node` is handed to
    /// `super::expr::child_expr`; otherwise an error is recorded on the context.
    fn visit_nested(&mut self, node: &mut Expr, attr: &Attribute) {
        debug_assert!(!self.scope.foreign);

        if let Err(e) = attr.meta.require_path_only() {
            self.cx.error(e);
        } else {
            super::expr::child_expr(self.cx, node);
        }
    }

    /// Expression level marker (`marker!` macro).
    ///
    /// Rewrites `marker!(<expr>)` into `Enum::VariantN(<expr>)` when the macro is exactly
    /// the context's marker.
    ///
    /// If the macro's tokens fail to parse as an expression, the error is recorded and a
    /// placeholder `compile_error!` expression is substituted. Whenever the context holds
    /// an error, the (parsed or placeholder) argument replaces the macro unwrapped.
    fn visit_marker_macro(&mut self, node: &mut Expr) {
        debug_assert!(!self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER);

        match node {
            // Desugar `marker!(<expr>)` into `Enum::VariantN(<expr>)`.
            // Skip if `marker!` is not a marker macro.
            Expr::Macro(ExprMacro { mac, .. }) if self.cx.is_marker_macro_exact(mac) => {
                replace_expr(node, |expr| {
                    let expr = if let Expr::Macro(expr) = expr { expr } else { unreachable!() };
                    let args = syn::parse2(expr.mac.tokens).unwrap_or_else(|e| {
                        self.cx.error(e);
                        // Generate an expression to fill in where the error occurred during the visit.
                        // These will eventually need to be replaced with the original error message.
                        parse_quote!(compile_error!(
                            "#[auto_enum] failed to generate error message"
                        ))
                    });

                    if self.cx.has_error() {
                        args
                    } else {
                        self.cx.next_expr_with_attrs(expr.attrs, args)
                    }
                });
            }
            _ => {}
        }
    }

    /// Visits one expression. The steps run in this order:
    ///
    /// 1. Update the scope (foreign `#[auto_enum]`, closures, try blocks).
    /// 2. Apply the mode-specific `return` / `?` rewrite.
    /// 3. Handle `#[nested]`.
    /// 4. Recurse through `VisitStmt::visit_expr`, which also expands an `#[auto_enum]`
    ///    attached to `node`.
    /// 5. Rewrite marker macros and strip helper attributes.
    ///
    /// The scope is restored before returning, so flags set for `node` do not leak to
    /// its siblings. `has_semi` is forwarded to `VisitStmt::visit_expr`.
    fn visit_expr(&mut self, node: &mut Expr, has_semi: bool) {
        debug_assert!(!self.cx.has_error());

        // Snapshot of the enclosing scope, restored at the end.
        let tmp = self.scope;

        if node.any_attr(NAME) {
            self.scope.foreign = true;
            // Record whether other `auto_enum` attribute exists.
            self.cx.has_child = true;
        }
        self.scope.check_expr(node);

        match self.cx.visit_mode {
            VisitMode::Return(count) => self.visit_return(node, count),
            VisitMode::Try => self.visit_try(node),
            VisitMode::Default => {}
        }

        if !self.scope.foreign {
            if let Some(attr) = node.find_remove_attr(NESTED) {
                self.visit_nested(node, &attr);
            }
        }

        VisitStmt::visit_expr(self, node, has_semi);

        // Inside a nested `#[auto_enum]`, marker macros are still handled when this context
        // uses a non-default marker.
        // NOTE(review): presumably the default marker then belongs to the nested
        // attribute — confirm.
        if !self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER {
            self.visit_marker_macro(node);
            self.find_remove_attrs(node);
        }

        self.scope = tmp;
    }
}
207
208impl VisitMut for Visitor<'_> {
209    fn visit_expr_mut(&mut self, node: &mut Expr) {
210        if !self.cx.has_error() {
211            self.visit_expr(node, false);
212        }
213    }
214
215    fn visit_arm_mut(&mut self, node: &mut Arm) {
216        if !self.cx.has_error() {
217            if !self.scope.foreign {
218                if let Some(attr) = node.find_remove_attr(NESTED) {
219                    self.visit_nested(&mut node.body, &attr);
220                }
221            }
222
223            visit_mut::visit_arm_mut(self, node);
224
225            self.find_remove_attrs(node);
226        }
227    }
228
229    fn visit_local_mut(&mut self, node: &mut Local) {
230        if !self.cx.has_error() {
231            if !self.scope.foreign {
232                if let Some(attr) = node.find_remove_attr(NESTED) {
233                    if let Some(LocalInit { expr, .. }) = &mut node.init {
234                        self.visit_nested(expr, &attr);
235                    }
236                }
237            }
238
239            visit_mut::visit_local_mut(self, node);
240
241            self.find_remove_attrs(node);
242        }
243    }
244
245    fn visit_stmt_mut(&mut self, node: &mut Stmt) {
246        if !self.cx.has_error() {
247            if let Stmt::Expr(expr, semi) = node {
248                self.visit_expr(expr, semi.is_some());
249            } else {
250                let tmp = self.scope;
251
252                if node.any_attr(NAME) {
253                    self.scope.foreign = true;
254                    // Record whether other `auto_enum` attribute exists.
255                    self.cx.has_child = true;
256                }
257
258                VisitStmt::visit_stmt(self, node);
259
260                self.scope = tmp;
261            }
262        }
263    }
264
265    fn visit_item_mut(&mut self, _: &mut Item) {
266        // Do not recurse into nested items.
267    }
268}
269
270impl VisitStmt for Visitor<'_> {
271    fn cx(&mut self) -> &mut Context {
272        self.cx
273    }
274}
275
276// -----------------------------------------------------------------------------
277// dummy visitor
278
279pub(super) struct Dummy<'a> {
280    cx: &'a mut Context,
281}
282
283impl<'a> Dummy<'a> {
284    pub(super) fn new(cx: &'a mut Context) -> Self {
285        Self { cx }
286    }
287}
288
289impl VisitMut for Dummy<'_> {
290    fn visit_stmt_mut(&mut self, node: &mut Stmt) {
291        if !self.cx.has_error() {
292            if node.any_attr(NAME) {
293                self.cx.has_child = true;
294            }
295            VisitStmt::visit_stmt(self, node);
296        }
297    }
298
299    fn visit_expr_mut(&mut self, node: &mut Expr) {
300        if !self.cx.has_error() {
301            if node.any_attr(NAME) {
302                self.cx.has_child = true;
303            }
304            VisitStmt::visit_expr(self, node, false);
305        }
306    }
307
308    fn visit_item_mut(&mut self, _: &mut Item) {
309        // Do not recurse into nested items.
310    }
311}
312
313impl VisitStmt for Dummy<'_> {
314    fn cx(&mut self) -> &mut Context {
315        self.cx
316    }
317}
318
319// -----------------------------------------------------------------------------
320// VisitStmt
321
322trait VisitStmt: VisitMut {
323    fn cx(&mut self) -> &mut Context;
324
325    fn visit_expr(visitor: &mut Self, node: &mut Expr, has_semi: bool) {
326        let attr = node.find_remove_attr(NAME);
327
328        let res = attr.map(|attr| {
329            attr.meta.require_list().and_then(|MetaList { tokens, .. }| {
330                visitor.cx().make_child(node.to_token_stream(), tokens.clone())
331            })
332        });
333
334        visit_mut::visit_expr_mut(visitor, node);
335
336        match res {
337            Some(Err(e)) => visitor.cx().error(e),
338            Some(Ok(mut cx)) => {
339                super::expand_parent_expr(&mut cx, node, has_semi);
340                visitor.cx().join_child(cx);
341            }
342            None => {}
343        }
344    }
345
346    fn visit_stmt(visitor: &mut Self, node: &mut Stmt) {
347        let attr = match node {
348            Stmt::Expr(expr, semi) => {
349                Self::visit_expr(visitor, expr, semi.is_some());
350                return;
351            }
352            Stmt::Local(local) => local.find_remove_attr(NAME),
353            Stmt::Macro(_) => None,
354            // Do not recurse into nested items.
355            Stmt::Item(_) => return,
356        };
357
358        let res = attr.map(|attr| {
359            let args = match attr.meta {
360                syn::Meta::Path(_) => TokenStream::new(),
361                syn::Meta::List(list) => list.tokens,
362                syn::Meta::NameValue(nv) => bail!(nv.eq_token, "expected list"),
363            };
364            visitor.cx().make_child(node.to_token_stream(), args)
365        });
366
367        visit_mut::visit_stmt_mut(visitor, node);
368
369        match res {
370            Some(Err(e)) => visitor.cx().error(e),
371            Some(Ok(mut cx)) => {
372                super::expand_parent_stmt(&mut cx, node);
373                visitor.cx().join_child(cx);
374            }
375            None => {}
376        }
377    }
378}
379
380// -----------------------------------------------------------------------------
381// FindNested
382
383/// Find `#[nested]` attribute.
384pub(super) fn find_nested(node: &mut impl Node) -> bool {
385    struct FindNested {
386        has: bool,
387    }
388
389    impl VisitMut for FindNested {
390        fn visit_expr_mut(&mut self, node: &mut Expr) {
391            if !node.any_attr(NAME) {
392                if node.any_empty_attr(NESTED) {
393                    self.has = true;
394                } else {
395                    visit_mut::visit_expr_mut(self, node);
396                }
397            }
398        }
399
400        fn visit_arm_mut(&mut self, node: &mut Arm) {
401            if node.any_empty_attr(NESTED) {
402                self.has = true;
403            } else {
404                visit_mut::visit_arm_mut(self, node);
405            }
406        }
407
408        fn visit_local_mut(&mut self, node: &mut Local) {
409            if !node.any_attr(NAME) {
410                if node.any_empty_attr(NESTED) {
411                    self.has = true;
412                } else {
413                    visit_mut::visit_local_mut(self, node);
414                }
415            }
416        }
417
418        fn visit_item_mut(&mut self, _: &mut Item) {
419            // Do not recurse into nested items.
420        }
421    }
422
423    let mut visitor = FindNested { has: false };
424    node.visited(&mut visitor);
425    visitor.has
426}
427
428// -----------------------------------------------------------------------------
429// FnVisitor
430
431#[derive(Default)]
432pub(super) struct FnCount {
433    pub(super) try_: usize,
434    pub(super) return_: usize,
435}
436
437pub(super) fn visit_fn(cx: &Context, node: &mut impl Node) -> FnCount {
438    struct FnVisitor<'a> {
439        cx: &'a Context,
440        scope: Scope,
441        count: FnCount,
442    }
443
444    impl VisitMut for FnVisitor<'_> {
445        fn visit_expr_mut(&mut self, node: &mut Expr) {
446            let tmp = self.scope;
447
448            self.scope.check_expr(node);
449
450            if !self.scope.closure && !node.any_empty_attr(NEVER) {
451                match node {
452                    Expr::Try(ExprTry { expr, .. }) => {
453                        // Skip if `<expr>` is a marker macro.
454                        if !self.cx.is_marker_expr(expr) {
455                            self.count.try_ += 1;
456                        }
457                    }
458                    Expr::Return(ExprReturn { expr, .. }) => {
459                        // Skip if `<expr>` is a marker macro.
460                        if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) {
461                            self.count.return_ += 1;
462                        }
463                    }
464                    _ => {}
465                }
466            }
467
468            if node.any_attr(NAME) {
469                self.scope.foreign = true;
470            }
471
472            visit_mut::visit_expr_mut(self, node);
473
474            self.scope = tmp;
475        }
476
477        fn visit_stmt_mut(&mut self, node: &mut Stmt) {
478            let tmp = self.scope;
479
480            if node.any_attr(NAME) {
481                self.scope.foreign = true;
482            }
483
484            visit_mut::visit_stmt_mut(self, node);
485
486            self.scope = tmp;
487        }
488
489        fn visit_item_mut(&mut self, _: &mut Item) {
490            // Do not recurse into nested items.
491        }
492    }
493
494    let mut visitor = FnVisitor { cx, scope: Scope::default(), count: FnCount::default() };
495    node.visited(&mut visitor);
496    visitor.count
497}