async_recursion/
parse.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use proc_macro2::Span;
use syn::{
    parse::{Error, Parse, ParseStream, Result},
    token::Question,
    ItemFn, Token,
};

pub struct AsyncItem(pub ItemFn);

impl Parse for AsyncItem {
    fn parse(input: ParseStream) -> Result<Self> {
        let item: ItemFn = input.parse()?;

        // Check that this is an async function
        if item.sig.asyncness.is_none() {
            return Err(Error::new(Span::call_site(), "expected an async function"));
        }

        Ok(AsyncItem(item))
    }
}

/// Options parsed from a comma-separated list of `?Send` / `Sync` arguments.
pub struct RecursionArgs {
    /// Whether the `Send` bound flag is set. Defaults to `true`; the `?Send`
    /// argument sets it to `false`.
    pub send_bound: bool,
    /// Whether the `Sync` bound flag is set. Defaults to `false`; the `Sync`
    /// argument sets it to `true`.
    pub sync_bound: bool,
}

/// Custom keywords for parser
///
/// `Send` and `Sync` are ordinary identifiers rather than Rust keywords, so
/// `syn::custom_keyword!` generates dedicated token types for them that can be
/// parsed with e.g. `input.parse::<kw::Send>()`. Inside this module the
/// generated structs shadow the standard `Send`/`Sync` marker traits.
mod kw {
    // Matches the identifier `Send` (expected right after `?` in `?Send`).
    syn::custom_keyword!(Send);
    // Matches the identifier `Sync`.
    syn::custom_keyword!(Sync);
}

/// One argument from the macro's argument list.
///
/// `PartialEq`/`Eq` are used to detect a repeated argument.
#[derive(Debug, PartialEq, Eq)]
enum Arg {
    /// Spelled `?Send`; clears [`RecursionArgs::send_bound`].
    NotSend,
    /// Spelled `Sync`; sets [`RecursionArgs::sync_bound`].
    Sync,
}

impl std::fmt::Display for Arg {
    /// Writes the argument exactly as it is spelled in source code.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let spelling = match self {
            Arg::NotSend => "?Send",
            Arg::Sync => "Sync",
        };
        f.write_str(spelling)
    }
}

impl Parse for Arg {
    /// Parses a single argument: either `?Send` or `Sync`.
    ///
    /// A lookahead is used on the first token so that an unrecognised argument
    /// reports every accepted starting token (`?` and `Sync`). Without it, the
    /// error would mention only the last alternative tried; a bare `Send`, for
    /// instance, would be reported as "expected `Sync`".
    ///
    /// # Errors
    ///
    /// Returns the lookahead error when the first token is neither `?` nor
    /// `Sync`, or the underlying parse error when `?` is not followed by `Send`.
    fn parse(input: ParseStream) -> Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![?]) {
            input.parse::<Question>()?;
            // After `?`, only `Send` is a meaningful relaxation.
            input.parse::<kw::Send>()?;
            Ok(Arg::NotSend)
        } else if lookahead.peek(kw::Sync) {
            input.parse::<kw::Sync>()?;
            Ok(Arg::Sync)
        } else {
            Err(lookahead.error())
        }
    }
}

impl Parse for RecursionArgs {
    /// Parses a comma-separated (trailing comma allowed) list of [`Arg`]s.
    ///
    /// With no arguments the result is `send_bound = true`,
    /// `sync_bound = false`; `?Send` clears the former and `Sync` sets the
    /// latter.
    ///
    /// # Errors
    ///
    /// - an argument fails to parse (the message is prefixed with
    ///   "failed to parse macro arguments: ");
    /// - more than two arguments are supplied;
    /// - exactly two arguments are supplied and they are identical.
    fn parse(input: ParseStream) -> Result<Self> {
        let args: Vec<Arg> =
            syn::punctuated::Punctuated::<Arg, syn::Token![,]>::parse_terminated(input)
                .map_err(|e| input.error(format!("failed to parse macro arguments: {e}")))?
                .into_iter()
                .collect();

        // Reject sloppy input. Since there are only two distinct arguments,
        // three or more necessarily contain a repeat.
        match args.as_slice() {
            [_, _, _, ..] => {
                return Err(Error::new(Span::call_site(), "received too many arguments"));
            }
            [first, second] if first == second => {
                return Err(Error::new(
                    Span::call_site(),
                    format!("received duplicate argument: `{}`", first),
                ));
            }
            _ => {}
        }

        Ok(Self {
            send_bound: !args.contains(&Arg::NotSend),
            sync_bound: args.contains(&Arg::Sync),
        })
    }
}