derive_utils/
parse.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use core::mem;
4use std::borrow::Cow;
5
6use proc_macro2::{TokenStream, TokenTree};
7use quote::{quote, ToTokens as _};
8use syn::{
9    parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn,
10    ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type,
11    TypeParamBound, TypePath, Visibility, WherePredicate,
12};
13
14use crate::ast::EnumData;
15
16/// A function for creating `proc_macro_derive` like deriving trait to enum so
17/// long as all variants are implemented that trait.
18///
19/// # Examples
20///
21/// ```
22/// # extern crate proc_macro;
23/// use derive_utils::derive_trait;
24/// use proc_macro::TokenStream;
25/// use quote::format_ident;
26/// use syn::{parse_macro_input, parse_quote};
27///
28/// # #[cfg(any())]
29/// #[proc_macro_derive(Iterator)]
30/// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
31/// pub fn derive_iterator(input: TokenStream) -> TokenStream {
32///     derive_trait(
33///         &parse_macro_input!(input),
34///         // trait path
35///         &parse_quote!(std::iter::Iterator),
36///         // super trait's associated types
37///         None,
38///         // trait definition
39///         parse_quote! {
40///             trait Iterator {
41///                 type Item;
42///                 fn next(&mut self) -> Option<Self::Item>;
43///                 fn size_hint(&self) -> (usize, Option<usize>);
44///             }
45///         },
46///     )
47///     .into()
48/// }
49///
50/// # #[cfg(any())]
51/// #[proc_macro_derive(ExactSizeIterator)]
52/// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
53/// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream {
54///     derive_trait(
55///         &parse_macro_input!(input),
56///         // trait path
57///         &parse_quote!(std::iter::ExactSizeIterator),
58///         // super trait's associated types
59///         Some(format_ident!("Item")),
60///         // trait definition
61///         parse_quote! {
62///             trait ExactSizeIterator: Iterator {
63///                 fn len(&self) -> usize;
64///             }
65///         },
66///     )
67///     .into()
68/// }
69/// ```
70pub fn derive_trait<I>(
71    data: &EnumData,
72    trait_path: &Path,
73    supertraits_types: I,
74    trait_def: ItemTrait,
75) -> TokenStream
76where
77    I: IntoIterator<Item = Ident>,
78    I::IntoIter: ExactSizeIterator,
79{
80    EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build()
81}
82
/// A builder for implementing a trait for enums.
///
/// Items are accumulated with the `push_*`/`append_*` methods and the final
/// impl is produced by `build` or `build_impl`.
pub struct EnumImpl<'a> {
    // The enum the impl is generated for; supplies the enum identifier, the
    // variant identifiers and the field types used by the generated items.
    data: &'a EnumData,
    // Whether `build_impl` emits the `default` keyword. Both constructors
    // initialize this to `false`.
    defaultness: bool,
    // Whether `build_impl` emits the `unsafe` keyword. `from_trait` copies
    // this from the trait definition.
    unsafety: bool,
    // Generic parameters and `where`-clause of the generated impl.
    generics: Generics,
    // Path of the implemented trait. When `None`, `build_impl` produces an
    // impl without a trait.
    trait_: Option<Path>,
    // The implementing type: the enum identifier followed by its type generics.
    self_ty: Box<Type>,
    // Items placed inside the generated impl, in insertion order.
    items: Vec<ImplItem>,
}
93
94impl<'a> EnumImpl<'a> {
95    /// Creates a new `EnumImpl`.
96    pub fn new(data: &'a EnumData) -> Self {
97        let ident = &data.ident;
98        let ty_generics = data.generics.split_for_impl().1;
99        Self {
100            data,
101            defaultness: false,
102            unsafety: false,
103            generics: data.generics.clone(),
104            trait_: None,
105            self_ty: Box::new(parse_quote!(#ident #ty_generics)),
106            items: vec![],
107        }
108    }
109
110    /// Creates a new `EnumImpl` from a trait definition.
111    ///
112    /// The following items are ignored:
113    /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics)
114    /// - [`TraitItem::Const`]
115    /// - [`TraitItem::Macro`]
116    /// - [`TraitItem::Verbatim`]
117    ///
118    /// # Panics
119    ///
120    /// Panics if a trait method has a body, no receiver, or a receiver other
121    /// than the following:
122    ///
123    /// - `&self`
124    /// - `&mut self`
125    /// - `self`
126    pub fn from_trait<I>(
127        data: &'a EnumData,
128        trait_path: &Path,
129        supertraits_types: I,
130        mut trait_def: ItemTrait,
131    ) -> Self
132    where
133        I: IntoIterator<Item = Ident>,
134        I::IntoIter: ExactSizeIterator,
135    {
136        let mut generics = data.generics.clone();
137        let trait_ = {
138            if trait_def.generics.params.is_empty() {
139                trait_path.clone()
140            } else {
141                let ty_generics = trait_def.generics.split_for_impl().1;
142                parse_quote!(#trait_path #ty_generics)
143            }
144        };
145
146        let fst = data.field_types().next().unwrap();
147        let mut types: Vec<_> = trait_def
148            .items
149            .iter()
150            .filter_map(|item| match item {
151                TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))),
152                _ => None,
153            })
154            .collect();
155
156        let supertraits_types = supertraits_types.into_iter();
157        if supertraits_types.len() > 0 {
158            if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() {
159                types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident))));
160            }
161        }
162
163        // https://github.com/taiki-e/derive_utils/issues/47
164        let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>();
165        let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..)));
166        if !has_method || !type_params.is_empty() {
167            struct HasTypeParam<'a>(&'a [String]);
168
169            impl HasTypeParam<'_> {
170                fn check_ident(&self, ident: &Ident) -> bool {
171                    let ident = ident.to_string();
172                    self.0.contains(&ident)
173                }
174
175                fn visit_type(&self, ty: &Type) -> bool {
176                    if let Type::Path(node) = ty {
177                        if node.qself.is_none() {
178                            if let Some(ident) = node.path.get_ident() {
179                                return self.check_ident(ident);
180                            }
181                        }
182                    }
183                    self.visit_token_stream(ty.to_token_stream())
184                }
185
186                fn visit_token_stream(&self, tokens: TokenStream) -> bool {
187                    for tt in tokens {
188                        match tt {
189                            TokenTree::Ident(ident) => {
190                                if self.check_ident(&ident) {
191                                    return true;
192                                }
193                            }
194                            TokenTree::Group(group) => {
195                                let content = group.stream();
196                                if self.visit_token_stream(content) {
197                                    return true;
198                                }
199                            }
200                            _ => {}
201                        }
202                    }
203                    false
204                }
205            }
206
207            let visitor = HasTypeParam(&type_params);
208            let where_clause = &mut generics.make_where_clause().predicates;
209            if !has_method || visitor.visit_type(fst) {
210                where_clause.push(parse_quote!(#fst: #trait_));
211            }
212            if data.field_types().len() > 1 {
213                let fst_tokens = fst.to_token_stream().to_string();
214                where_clause.extend(data.field_types().skip(1).filter_map(
215                    |variant| -> Option<WherePredicate> {
216                        if has_method && !visitor.visit_type(variant) {
217                            return None;
218                        }
219                        if variant.to_token_stream().to_string() == fst_tokens {
220                            return None;
221                        }
222                        if types.is_empty() {
223                            return Some(parse_quote!(#variant: #trait_));
224                        }
225                        let types = types.iter().map(|(supertraits, ident)| {
226                            match trait_def.supertraits.iter().next() {
227                                Some(TypeParamBound::Trait(trait_)) if *supertraits => {
228                                    quote!(#ident = <#fst as #trait_>::#ident)
229                                }
230                                _ => quote!(#ident = <#fst as #trait_>::#ident),
231                            }
232                        });
233                        if trait_def.generics.params.is_empty() {
234                            Some(parse_quote!(#variant: #trait_path<#(#types),*>))
235                        } else {
236                            let generics =
237                                trait_def.generics.params.iter().map(|param| match param {
238                                    GenericParam::Lifetime(def) => def.lifetime.to_token_stream(),
239                                    GenericParam::Type(param) => param.ident.to_token_stream(),
240                                    GenericParam::Const(param) => param.ident.to_token_stream(),
241                                });
242                            Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>))
243                        }
244                    },
245                ));
246            }
247        }
248
249        if !trait_def.generics.params.is_empty() {
250            generics.params.extend(mem::take(&mut trait_def.generics.params));
251        }
252
253        if let Some(old) = trait_def.generics.where_clause.as_mut() {
254            if !old.predicates.is_empty() {
255                generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates));
256            }
257        }
258
259        let ident = &data.ident;
260        let ty_generics = data.generics.split_for_impl().1;
261        let mut impls = Self {
262            data,
263            defaultness: false,
264            unsafety: trait_def.unsafety.is_some(),
265            generics,
266            trait_: Some(trait_),
267            self_ty: Box::new(parse_quote!(#ident #ty_generics)),
268            items: Vec::with_capacity(trait_def.items.len()),
269        };
270        impls.append_items_from_trait(trait_def);
271        impls
272    }
273
274    pub fn set_trait(&mut self, path: Path) {
275        self.trait_ = Some(path);
276    }
277
278    /// Appends a generic type parameter to the back of generics.
279    pub fn push_generic_param(&mut self, param: GenericParam) {
280        self.generics.params.push(param);
281    }
282
283    /// Appends a predicate to the back of `where`-clause.
284    pub fn push_where_predicate(&mut self, predicate: WherePredicate) {
285        self.generics.make_where_clause().predicates.push(predicate);
286    }
287
288    /// Appends an item to impl items.
289    pub fn push_item(&mut self, item: ImplItem) {
290        self.items.push(item);
291    }
292
293    /// Appends a method to impl items.
294    ///
295    /// # Panics
296    ///
297    /// Panics if a trait method has a body, no receiver, or a receiver other
298    /// than the following:
299    ///
300    /// - `&self`
301    /// - `&mut self`
302    /// - `self`
303    pub fn push_method(&mut self, item: TraitItemFn) {
304        assert!(item.default.is_none(), "method `{}` has a body", item.sig.ident);
305
306        let self_ty = ReceiverKind::new(&item.sig);
307        let mut args = Vec::with_capacity(item.sig.inputs.len());
308        item.sig.inputs.iter().skip(1).for_each(|arg| match arg {
309            FnArg::Typed(arg) => args.push(&arg.pat),
310            FnArg::Receiver(_) => panic!(
311                "method `{}` has a receiver in a position other than the first argument",
312                item.sig.ident
313            ),
314        });
315
316        let method = &item.sig.ident;
317        let ident = &self.data.ident;
318        let method = match self_ty {
319            ReceiverKind::Normal => match &self.trait_ {
320                None => {
321                    let arms = self.data.variant_idents().map(|v| {
322                        quote! {
323                            #ident::#v(x) => x.#method(#(#args),*),
324                        }
325                    });
326                    parse_quote!(match self { #(#arms)* })
327                }
328                Some(trait_) => {
329                    let arms =
330                        self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| {
331                            quote! {
332                                #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*),
333                            }
334                        });
335                    parse_quote!(match self { #(#arms)* })
336                }
337            },
338        };
339
340        self.push_item(ImplItem::Fn(ImplItemFn {
341            attrs: item.attrs,
342            vis: Visibility::Inherited,
343            defaultness: None,
344            sig: item.sig,
345            block: Block {
346                brace_token: token::Brace::default(),
347                stmts: vec![Stmt::Expr(method, None)],
348            },
349        }));
350    }
351
352    /// Appends items from a trait definition to impl items.
353    ///
354    /// # Panics
355    ///
356    /// Panics if a trait method has a body, no receiver, or a receiver other
357    /// than the following:
358    ///
359    /// - `&self`
360    /// - `&mut self`
361    /// - `self`
362    pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) {
363        let fst = self.data.field_types().next();
364        trait_def.items.into_iter().for_each(|item| match item {
365            // The TraitItemType::generics field (Generic associated types (GAT)) are not supported
366            TraitItem::Type(TraitItemType { ident, .. }) => {
367                let trait_ = &self.trait_;
368                let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;);
369                self.push_item(ImplItem::Type(ty));
370            }
371            TraitItem::Fn(method) => self.push_method(method),
372            _ => {}
373        });
374    }
375
376    pub fn build(self) -> TokenStream {
377        self.build_impl().to_token_stream()
378    }
379
380    pub fn build_impl(self) -> ItemImpl {
381        ItemImpl {
382            attrs: vec![parse_quote!(#[automatically_derived])],
383            defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None },
384            unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None },
385            impl_token: token::Impl::default(),
386            generics: self.generics,
387            trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())),
388            self_ty: self.self_ty,
389            brace_token: token::Brace::default(),
390            items: self.items,
391        }
392    }
393}
394
// Kind of method receiver recognized by `ReceiverKind::new`; any receiver
// that does not fit a variant makes `new` panic.
enum ReceiverKind {
    /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self`
    Normal,
}
399
400impl ReceiverKind {
401    fn new(sig: &Signature) -> Self {
402        fn get_ty_path(ty: &Type) -> Option<&Path> {
403            if let Type::Path(TypePath { qself: None, path }) = ty {
404                Some(path)
405            } else {
406                None
407            }
408        }
409
410        match sig.receiver() {
411            None => panic!("method `{}` has no receiver", sig.ident),
412            Some(receiver) => {
413                if receiver.colon_token.is_none() {
414                    return ReceiverKind::Normal;
415                }
416                match &*receiver.ty {
417                    Type::Path(TypePath { qself: None, path }) => {
418                        // (mut) self: Self
419                        if path.is_ident("Self") {
420                            return ReceiverKind::Normal;
421                        }
422                    }
423                    Type::Reference(ty) => {
424                        // (mut) self: &(mut) Self
425                        if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self")) {
426                            return ReceiverKind::Normal;
427                        }
428                    }
429                    _ => {}
430                }
431                panic!(
432                    "method `{}` has unsupported receiver type: {}",
433                    sig.ident,
434                    receiver.ty.to_token_stream()
435                );
436            }
437        }
438    }
439}