async_recursion/
expand.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
5    ReturnType, Signature, TypeReference, WhereClause,
6};
7
8use crate::parse::{AsyncItem, RecursionArgs};
9
10impl ToTokens for AsyncItem {
11    fn to_tokens(&self, tokens: &mut TokenStream) {
12        self.0.to_tokens(tokens);
13    }
14}
15
16pub fn expand(item: &mut AsyncItem, args: &RecursionArgs) {
17    item.0.attrs.push(parse_quote!(#[must_use]));
18    transform_sig(&mut item.0.sig, args);
19    transform_block(&mut item.0.block);
20}
21
22fn transform_block(block: &mut Block) {
23    let brace = block.brace_token;
24    *block = parse_quote!({
25        Box::pin(async move #block)
26    });
27    block.brace_token = brace;
28}
29
30enum ArgLifetime {
31    New(Lifetime),
32    Existing(Lifetime),
33}
34
35impl ArgLifetime {
36    pub fn lifetime(self) -> Lifetime {
37        match self {
38            ArgLifetime::New(lt) | ArgLifetime::Existing(lt) => lt,
39        }
40    }
41}
42
43#[derive(Default)]
44struct ReferenceVisitor {
45    counter: usize,
46    lifetimes: Vec<ArgLifetime>,
47    self_receiver: bool,
48    self_receiver_new_lifetime: bool,
49    self_lifetime: Option<Lifetime>,
50}
51
52impl VisitMut for ReferenceVisitor {
53    fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
54        self.self_lifetime = Some(if let Some((_, lt)) = &mut receiver.reference {
55            self.self_receiver = true;
56
57            if let Some(lt) = lt {
58                lt.clone()
59            } else {
60                // Use 'life_self to avoid collisions with 'life<count> lifetimes.
61                let new_lifetime: Lifetime = parse_quote!('life_self);
62                lt.replace(new_lifetime.clone());
63
64                self.self_receiver_new_lifetime = true;
65
66                new_lifetime
67            }
68        } else {
69            return;
70        });
71    }
72
73    fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
74        if argument.lifetime.is_none() {
75            // If this reference doesn't have a lifetime (e.g. &T), then give it one.
76            let lt = Lifetime::new(&format!("'life{}", self.counter), Span::call_site());
77            self.lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
78            argument.lifetime = Some(lt);
79            self.counter += 1;
80        } else {
81            // If it does (e.g. &'life T), then keep track of it.
82            let lt = argument.lifetime.as_ref().cloned().unwrap();
83
84            // Check that this lifetime isn't already in our vector
85            let ident_matches = |x: &ArgLifetime| {
86                if let ArgLifetime::Existing(elt) = x {
87                    elt.ident == lt.ident
88                } else {
89                    false
90                }
91            };
92
93            if !self.lifetimes.iter().any(ident_matches) {
94                self.lifetimes.push(ArgLifetime::Existing(lt));
95            }
96        }
97    }
98}
99
100// Input:
101//     async fn f<S, T>(x : S, y : &T) -> Ret;
102//
103// Output:
104//     fn f<S, T>(x : S, y : &T) -> Pin<Box<dyn Future<Output = Ret> + Send>
105fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
106    // Determine the original return type
107    let ret = match &sig.output {
108        ReturnType::Default => quote!(()),
109        ReturnType::Type(_, ret) => quote!(#ret),
110    };
111
112    // Remove the asyncness of this function
113    sig.asyncness = None;
114
115    // Find and update any references in the input arguments
116    let mut v = ReferenceVisitor::default();
117    for input in &mut sig.inputs {
118        v.visit_fn_arg_mut(input);
119    }
120
121    // Does this expansion require `async_recursion to be added to the output?
122    let mut requires_lifetime = false;
123    let mut where_clause_lifetimes = vec![];
124    let mut where_clause_generics = vec![];
125
126    // 'async_recursion lifetime
127    let asr: Lifetime = parse_quote!('async_recursion);
128
129    // Add an S : 'async_recursion bound to any generic parameter
130    for param in sig.generics.type_params() {
131        let ident = param.ident.clone();
132        where_clause_generics.push(ident);
133        requires_lifetime = true;
134    }
135
136    // Add an 'a : 'async_recursion bound to any lifetimes 'a appearing in the function
137    if !v.lifetimes.is_empty() {
138        requires_lifetime = true;
139        for alt in v.lifetimes {
140            if let ArgLifetime::New(lt) = &alt {
141                // If this is a new argument,
142                sig.generics.params.push(parse_quote!(#lt));
143            }
144
145            // Add a bound to the where clause
146            let lt = alt.lifetime();
147            where_clause_lifetimes.push(lt);
148        }
149    }
150
151    // If our function accepts &self, then we modify this to the explicit lifetime &'life_self,
152    // and add the bound &'life_self : 'async_recursion
153    if v.self_receiver {
154        if v.self_receiver_new_lifetime {
155            sig.generics.params.push(parse_quote!('life_self));
156        }
157        where_clause_lifetimes.extend(v.self_lifetime);
158        requires_lifetime = true;
159    }
160
161    let box_lifetime: TokenStream = if requires_lifetime {
162        // Add 'async_recursion to our generic parameters
163        sig.generics.params.push(parse_quote!('async_recursion));
164
165        quote!(+ #asr)
166    } else {
167        quote!()
168    };
169
170    let send_bound: TokenStream = if args.send_bound {
171        quote!(+ ::core::marker::Send)
172    } else {
173        quote!()
174    };
175
176    let sync_bound: TokenStream = if args.sync_bound {
177        quote!(+ ::core::marker::Sync)
178    } else {
179        quote!()
180    };
181
182    let where_clause = sig
183        .generics
184        .where_clause
185        .get_or_insert_with(|| WhereClause {
186            where_token: Default::default(),
187            predicates: Punctuated::new(),
188        });
189
190    // Add our S : 'async_recursion bounds
191    for generic_ident in where_clause_generics {
192        where_clause
193            .predicates
194            .push(parse_quote!(#generic_ident : #asr));
195    }
196
197    // Add our 'a : 'async_recursion bounds
198    for lifetime in where_clause_lifetimes {
199        where_clause.predicates.push(parse_quote!(#lifetime : #asr));
200    }
201
202    // Modify the return type
203    sig.output = parse_quote! {
204        -> ::core::pin::Pin<Box<
205            dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound #sync_bound>>
206    };
207}