auto_enums/
enum_derive.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use std::cell::Cell;
4
5use derive_utils::EnumData as Data;
6use proc_macro2::TokenStream;
7use quote::{quote, ToTokens as _};
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_quote, Error, ItemEnum, Path, Result, Token,
11};
12
13pub(crate) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream {
14    expand(args, input).unwrap_or_else(Error::into_compile_error)
15}
16
/// State shared by all built-in derives run for a single `#[enum_derive]` expansion.
///
/// Derive functions only get `&DeriveContext`, so the flag lives in a [`Cell`]
/// to allow setting it through a shared reference.
#[derive(Default)]
pub(crate) struct DeriveContext {
    /// Starts `false`; once set to `true`, `expand` additionally emits a
    /// conditional `Unpin` impl and a check forbidding a `Drop` impl on the enum.
    needs_pin_projection: Cell<bool>,
}

impl DeriveContext {
    /// Marks that some generated impl requires the pin-projection safety guards.
    ///
    /// Calling this repeatedly has the same effect as calling it once.
    pub(crate) fn needs_pin_projection(&self) {
        Cell::set(&self.needs_pin_projection, true);
    }
}
27
28type DeriveFn = fn(&'_ DeriveContext, &'_ Data) -> Result<TokenStream>;
29
30fn get_derive(s: &str) -> Option<DeriveFn> {
31    macro_rules! match_derive {
32        ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$(
33            $(#[$meta])*
34            {
35                if crate::derive::$($arm)::*::NAME.iter().any(|name| *name == s) {
36                    return Some(crate::derive::$($arm)::*::derive)
37                }
38            }
39        )*};
40    }
41
42    match_derive! {
43        // core
44        #[cfg(feature = "convert")]
45        core::convert::as_mut,
46        #[cfg(feature = "convert")]
47        core::convert::as_ref,
48        core::fmt::debug,
49        core::fmt::display,
50        #[cfg(feature = "fmt")]
51        core::fmt::pointer,
52        #[cfg(feature = "fmt")]
53        core::fmt::binary,
54        #[cfg(feature = "fmt")]
55        core::fmt::octal,
56        #[cfg(feature = "fmt")]
57        core::fmt::upper_hex,
58        #[cfg(feature = "fmt")]
59        core::fmt::lower_hex,
60        #[cfg(feature = "fmt")]
61        core::fmt::upper_exp,
62        #[cfg(feature = "fmt")]
63        core::fmt::lower_exp,
64        core::fmt::write,
65        core::iter::iterator,
66        core::iter::double_ended_iterator,
67        core::iter::exact_size_iterator,
68        core::iter::fused_iterator,
69        #[cfg(feature = "trusted_len")]
70        core::iter::trusted_len,
71        core::iter::extend,
72        #[cfg(feature = "ops")]
73        core::ops::deref,
74        #[cfg(feature = "ops")]
75        core::ops::deref_mut,
76        #[cfg(feature = "ops")]
77        core::ops::index,
78        #[cfg(feature = "ops")]
79        core::ops::index_mut,
80        #[cfg(feature = "ops")]
81        core::ops::range_bounds,
82        #[cfg(feature = "fn_traits")]
83        core::ops::fn_,
84        #[cfg(feature = "fn_traits")]
85        core::ops::fn_mut,
86        #[cfg(feature = "fn_traits")]
87        core::ops::fn_once,
88        #[cfg(feature = "coroutine_trait")]
89        core::ops::coroutine,
90        core::future,
91        // std
92        #[cfg(feature = "std")]
93        std::io::read,
94        #[cfg(feature = "std")]
95        std::io::buf_read,
96        #[cfg(feature = "std")]
97        std::io::seek,
98        #[cfg(feature = "std")]
99        std::io::write,
100        #[cfg(feature = "std")]
101        std::error,
102        // type impls
103        #[cfg(feature = "transpose_methods")]
104        ty_impls::transpose,
105        // futures03
106        #[cfg(feature = "futures03")]
107        external::futures03::stream,
108        #[cfg(feature = "futures03")]
109        external::futures03::sink,
110        #[cfg(feature = "futures03")]
111        external::futures03::async_read,
112        #[cfg(feature = "futures03")]
113        external::futures03::async_write,
114        #[cfg(feature = "futures03")]
115        external::futures03::async_seek,
116        #[cfg(feature = "futures03")]
117        external::futures03::async_buf_read,
118        // futures01
119        #[cfg(feature = "futures01")]
120        external::futures01::future,
121        #[cfg(feature = "futures01")]
122        external::futures01::stream,
123        #[cfg(feature = "futures01")]
124        external::futures01::sink,
125        // rayon
126        #[cfg(feature = "rayon")]
127        external::rayon::par_iter,
128        #[cfg(feature = "rayon")]
129        external::rayon::indexed_par_iter,
130        #[cfg(feature = "rayon")]
131        external::rayon::par_extend,
132        // serde
133        #[cfg(feature = "serde")]
134        external::serde::serialize,
135        // tokio1
136        #[cfg(feature = "tokio1")]
137        external::tokio1::async_read,
138        #[cfg(feature = "tokio1")]
139        external::tokio1::async_write,
140        #[cfg(feature = "tokio1")]
141        external::tokio1::async_seek,
142        #[cfg(feature = "tokio1")]
143        external::tokio1::async_buf_read,
144        // tokio03
145        #[cfg(feature = "tokio03")]
146        external::tokio03::async_read,
147        #[cfg(feature = "tokio03")]
148        external::tokio03::async_write,
149        #[cfg(feature = "tokio03")]
150        external::tokio03::async_seek,
151        #[cfg(feature = "tokio03")]
152        external::tokio03::async_buf_read,
153        // tokio02
154        #[cfg(feature = "tokio02")]
155        external::tokio02::async_read,
156        #[cfg(feature = "tokio02")]
157        external::tokio02::async_write,
158        #[cfg(feature = "tokio02")]
159        external::tokio02::async_seek,
160        #[cfg(feature = "tokio02")]
161        external::tokio02::async_buf_read,
162        // tokio01
163        #[cfg(feature = "tokio01")]
164        external::tokio01::async_read,
165        #[cfg(feature = "tokio01")]
166        external::tokio01::async_write,
167        // http_body1
168        #[cfg(feature = "http_body1")]
169        external::http_body1::body,
170    }
171
172    None
173}
174
175struct Args {
176    inner: Vec<(String, Path)>,
177}
178
179impl Parse for Args {
180    fn parse(input: ParseStream<'_>) -> Result<Self> {
181        fn to_trimmed_string(p: &Path) -> String {
182            p.to_token_stream().to_string().replace(' ', "")
183        }
184
185        let mut inner = vec![];
186        while !input.is_empty() {
187            let path = input.parse()?;
188            inner.push((to_trimmed_string(&path), path));
189
190            if input.is_empty() {
191                break;
192            }
193            let _: Token![,] = input.parse()?;
194        }
195
196        Ok(Self { inner })
197    }
198}
199
/// Returns the supertraits (direct and transitive) that must also be
/// implemented before trait `s` can be, or `None` if `s` has no registered
/// prerequisites.
///
/// Entries for traits behind optional crate features exist only when that
/// feature is enabled.
fn get_trait_deps(s: &str) -> Option<&'static [&'static str]> {
    let deps: &'static [&'static str] = match s {
        "Copy" => &["Clone"],
        "Eq" | "PartialOrd" => &["PartialEq"],
        "Ord" => &["PartialOrd", "Eq", "PartialEq"],
        #[cfg(feature = "ops")]
        "DerefMut" => &["Deref"],
        #[cfg(feature = "ops")]
        "IndexMut" => &["Index"],
        #[cfg(feature = "fn_traits")]
        "Fn" => &["FnMut", "FnOnce"],
        #[cfg(feature = "fn_traits")]
        "FnMut" => &["FnOnce"],
        "DoubleEndedIterator" | "ExactSizeIterator" | "FusedIterator" => &["Iterator"],
        #[cfg(feature = "trusted_len")]
        "TrustedLen" => &["Iterator"],
        #[cfg(feature = "std")]
        "BufRead" | "io::BufRead" => &["Read"],
        #[cfg(feature = "std")]
        "Error" => &["Display", "Debug"],
        #[cfg(feature = "rayon")]
        "rayon::IndexedParallelIterator" => &["rayon::ParallelIterator"],
        _ => return None,
    };
    Some(deps)
}
225
226fn exists_alias(s: &str, v: &[(&str, Option<&Path>)]) -> bool {
227    fn get_alias(s: &str) -> Option<&'static str> {
228        macro_rules! match_alias {
229            ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$(
230                $(#[$meta])*
231                {
232                    if s == crate::derive::$($arm)::*::NAME[0] {
233                        return Some(crate::derive::$($arm)::*::NAME[1]);
234                    } else if s == crate::derive::$($arm)::*::NAME[1] {
235                        return Some(crate::derive::$($arm)::*::NAME[0]);
236                    }
237                }
238            )*};
239        }
240
241        match_alias! {
242            // core
243            core::fmt::debug,
244            core::fmt::display,
245            // std
246            #[cfg(feature = "std")]
247            std::io::read,
248            #[cfg(feature = "std")]
249            std::io::buf_read,
250            #[cfg(feature = "std")]
251            std::io::seek,
252            #[cfg(feature = "std")]
253            std::io::write,
254        }
255
256        None
257    }
258
259    get_alias(s).map_or(false, |x| v.iter().any(|(s, _)| *s == x))
260}
261
262fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
263    let data = syn::parse2::<Data>(input)?;
264    let args = syn::parse2::<Args>(args)?.inner;
265    let args = args.iter().fold(vec![], |mut v, (s, arg)| {
266        if let Some(traits) = get_trait_deps(s) {
267            for s in traits.iter().filter(|&x| !args.iter().any(|(s, _)| s == x)) {
268                if !exists_alias(s, &v) {
269                    v.push((s, None));
270                }
271            }
272        }
273        if !exists_alias(s, &v) {
274            v.push((s, Some(arg)));
275        }
276        v
277    });
278
279    let mut derive = vec![];
280    let mut items = TokenStream::new();
281    let cx = DeriveContext::default();
282    for (s, arg) in args {
283        match (get_derive(s), arg) {
284            (Some(f), _) => {
285                items.extend(
286                    f(&cx, &data).map_err(|e| format_err!(data, "`enum_derive({})` {}", s, e))?,
287                );
288            }
289            (_, Some(arg)) => derive.push(arg),
290            _ => {}
291        }
292    }
293
294    let mut item = if cx.needs_pin_projection.get() {
295        // If a user creates their own Unpin or Drop implementation, trait implementations with
296        // `Pin<&mut self>` receiver can cause unsoundness.
297        //
298        // This was not a problem in #[auto_enum] attribute where enums are anonymized,
299        // but it becomes a problem when users have access to enums (i.e., when using #[enum_derive]).
300        //
301        // So, we ensure safety here by an Unpin implementation that implements Unpin
302        // only if all fields are Unpin (this also forbids custom Unpin implementation),
303        // and a hack that forbids custom Drop implementation. (Both are what pin-project does by default.)
304        // The repr(packed) check is not needed since repr(packed) is not available on enum.
305
306        // Automatically create the appropriate conditional `Unpin` implementation.
307        // https://github.com/taiki-e/pin-project/blob/v1.1.5/examples/struct-default-expanded.rs#L98
308        // TODO: use https://github.com/taiki-e/pin-project/issues/102#issuecomment-540472282's trick.
309        // TODO: https://github.com/taiki-e/pin-project/pull/357 is also needed?
310        items.extend(derive_utils::derive_trait(
311            &data,
312            &parse_quote!(::core::marker::Unpin),
313            None,
314            parse_quote! {
315                trait Unpin {}
316            },
317        ));
318
319        let item: ItemEnum = data.into();
320        let name = &item.ident;
321        let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
322        // Ensure that enum does not implement `Drop`.
323        // https://github.com/taiki-e/pin-project/blob/v1.1.5/examples/struct-default-expanded.rs#L147
324        items.extend(quote! {
325            const _: () = {
326                trait MustNotImplDrop {}
327                #[allow(clippy::drop_bounds, drop_bounds)]
328                #[automatically_derived]
329                impl<T: ::core::ops::Drop> MustNotImplDrop for T {}
330                #[automatically_derived]
331                impl #impl_generics MustNotImplDrop for #name #ty_generics #where_clause {}
332            };
333        });
334        item
335    } else {
336        data.into()
337    };
338
339    if !derive.is_empty() {
340        item.attrs.push(parse_quote!(#[derive(#(#derive),*)]));
341    }
342
343    let mut item = item.into_token_stream();
344    item.extend(items);
345    Ok(item)
346}