zerovec_derive/
utils.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use quote::{quote, ToTokens};
6
7use proc_macro2::Span;
8use proc_macro2::TokenStream as TokenStream2;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Attribute, Error, Field, Fields, Ident, Index, Result, Token};
13
14#[derive(Default)]
15pub struct ReprInfo {
16    pub c: bool,
17    pub transparent: bool,
18    pub u8: bool,
19    pub packed: bool,
20}
21
22impl ReprInfo {
23    pub fn compute(attrs: &[Attribute]) -> Self {
24        let mut info = ReprInfo::default();
25        for attr in attrs.iter().filter(|a| a.path().is_ident("repr")) {
26            if let Ok(pieces) = attr.parse_args::<IdentListAttribute>() {
27                for piece in pieces.idents.iter() {
28                    if piece == "C" || piece == "c" {
29                        info.c = true;
30                    } else if piece == "transparent" {
31                        info.transparent = true;
32                    } else if piece == "packed" {
33                        info.packed = true;
34                    } else if piece == "u8" {
35                        info.u8 = true;
36                    }
37                }
38            }
39        }
40        info
41    }
42
43    pub fn cpacked_or_transparent(self) -> bool {
44        (self.c && self.packed) || self.transparent
45    }
46}
47
48// An attribute that is a list of idents
49struct IdentListAttribute {
50    idents: Punctuated<Ident, Token![,]>,
51}
52
53impl Parse for IdentListAttribute {
54    fn parse(input: ParseStream) -> Result<Self> {
55        Ok(IdentListAttribute {
56            idents: input.parse_terminated(Ident::parse, Token![,])?,
57        })
58    }
59}
60
61/// Given a set of entries for struct field definitions to go inside a `struct {}` definition,
62/// wrap in a () or {} based on the type of field
63pub fn wrap_field_inits(streams: &[TokenStream2], fields: &Fields) -> TokenStream2 {
64    match *fields {
65        Fields::Named(_) => quote!( { #(#streams),* } ),
66        Fields::Unnamed(_) => quote!( ( #(#streams),* ) ),
67        Fields::Unit => {
68            unreachable!("#[make_(var)ule] should have already checked that there are fields")
69        }
70    }
71}
72
73/// Return a semicolon token if necessary after the struct definition
74pub fn semi_for(f: &Fields) -> TokenStream2 {
75    if let Fields::Unnamed(..) = *f {
76        quote!(;)
77    } else {
78        quote!()
79    }
80}
81
82/// Returns the repr attribute to be applied to the resultant ULE or VarULE type
83pub fn repr_for(f: &Fields) -> TokenStream2 {
84    if f.len() == 1 {
85        quote!(transparent)
86    } else {
87        quote!(C, packed)
88    }
89}
90
91fn suffixed_ident(name: &str, suffix: usize, s: Span) -> Ident {
92    Ident::new(&format!("{name}_{suffix}"), s)
93}
94
95/// Given an iterator over ULE or AsULE struct fields, returns code that calculates field sizes and generates a line
96/// of code per field based on the per_field_code function (whose parameters are the field, the identifier of the const
97/// for the previous offset, the identifier for the const for the next offset, and the field index)
98pub(crate) fn generate_per_field_offsets<'a>(
99    fields: &[FieldInfo<'a>],
100    // Whether the fields are ULE types or AsULE (and need conversion)
101    fields_are_asule: bool,
102    // (field, prev_offset_ident, size_ident)
103    mut per_field_code: impl FnMut(&FieldInfo<'a>, &Ident, &Ident) -> TokenStream2, /* (code, remaining_offset) */
104) -> (TokenStream2, syn::Ident) {
105    let mut prev_offset_ident = Ident::new("ZERO", Span::call_site());
106    let mut code = quote!(
107        const ZERO: usize = 0;
108    );
109
110    for (i, field_info) in fields.iter().enumerate() {
111        let field = &field_info.field;
112        let ty = &field.ty;
113        let ty = if fields_are_asule {
114            quote!(<#ty as zerovec::ule::AsULE>::ULE)
115        } else {
116            quote!(#ty)
117        };
118        let new_offset_ident = suffixed_ident("OFFSET", i, field.span());
119        let size_ident = suffixed_ident("SIZE", i, field.span());
120        let pf_code = per_field_code(field_info, &prev_offset_ident, &size_ident);
121        code = quote! {
122            #code;
123            const #size_ident: usize = ::core::mem::size_of::<#ty>();
124            const #new_offset_ident: usize = #prev_offset_ident + #size_ident;
125            #pf_code;
126        };
127
128        prev_offset_ident = new_offset_ident;
129    }
130
131    (code, prev_offset_ident)
132}
133
134#[derive(Clone, Debug)]
135pub(crate) struct FieldInfo<'a> {
136    pub accessor: TokenStream2,
137    pub field: &'a Field,
138    pub index: usize,
139}
140
141impl<'a> FieldInfo<'a> {
142    pub fn make_list(iter: impl Iterator<Item = &'a Field>) -> Vec<Self> {
143        iter.enumerate()
144            .map(|(i, field)| Self::new_for_field(field, i))
145            .collect()
146    }
147
148    pub fn new_for_field(f: &'a Field, index: usize) -> Self {
149        if let Some(ref i) = f.ident {
150            FieldInfo {
151                accessor: quote!(#i),
152                field: f,
153                index,
154            }
155        } else {
156            let idx = Index::from(index);
157            FieldInfo {
158                accessor: quote!(#idx),
159                field: f,
160                index,
161            }
162        }
163    }
164
165    /// Get the code for setting this field in struct decl/brace syntax
166    ///
167    /// Use self.accessor for dot-notation accesses
168    pub fn setter(&self) -> TokenStream2 {
169        if let Some(ref i) = self.field.ident {
170            quote!(#i: )
171        } else {
172            quote!()
173        }
174    }
175
176    /// Produce a name for a getter for the field
177    pub fn getter(&self) -> TokenStream2 {
178        if let Some(ref i) = self.field.ident {
179            quote!(#i)
180        } else {
181            suffixed_ident("field", self.index, self.field.span()).into_token_stream()
182        }
183    }
184
185    /// Produce a prose name for the field for use in docs
186    pub fn getter_doc_name(&self) -> String {
187        if let Some(ref i) = self.field.ident {
188            format!("the unsized `{i}` field")
189        } else {
190            format!("tuple struct field #{}", self.index)
191        }
192    }
193}
194
195/// Extracts all `zerovec::name(..)` attribute
196pub fn extract_parenthetical_zerovec_attrs(
197    attrs: &mut Vec<Attribute>,
198    name: &str,
199) -> Result<Vec<Ident>> {
200    let mut ret = vec![];
201    let mut error = None;
202    attrs.retain(|a| {
203        // skip the "zerovec" part
204        let second_segment = a.path().segments.iter().nth(1);
205
206        if let Some(second) = second_segment {
207            if second.ident == name {
208                let list = match a.parse_args::<IdentListAttribute>() {
209                    Ok(l) => l,
210                    Err(_) => {
211                        error = Some(Error::new(
212                            a.span(),
213                            format!("#[zerovec::{name}(..)] takes in a comma separated list of identifiers"),
214                        ));
215                        return false;
216                    }
217                };
218                ret.extend(list.idents.iter().cloned());
219                return false;
220            }
221        }
222
223        true
224    });
225
226    if let Some(error) = error {
227        return Err(error);
228    }
229    Ok(ret)
230}
231
232pub fn extract_single_tt_attr(
233    attrs: &mut Vec<Attribute>,
234    name: &str,
235) -> Result<Option<TokenStream2>> {
236    let mut ret = None;
237    let mut error = None;
238    attrs.retain(|a| {
239        // skip the "zerovec" part
240        let second_segment = a.path().segments.iter().nth(1);
241
242        if let Some(second) = second_segment {
243            if second.ident == name {
244                if ret.is_some() {
245                    error = Some(Error::new(
246                        a.span(),
247                        "Can only specify a single VarZeroVecFormat via #[zerovec::format(..)]",
248                    ));
249                    return false
250                }
251                ret = match a.parse_args::<TokenStream2>() {
252                    Ok(l) => Some(l),
253                    Err(_) => {
254                        error = Some(Error::new(
255                            a.span(),
256                            format!("#[zerovec::{name}(..)] takes in a comma separated list of identifiers"),
257                        ));
258                        return false;
259                    }
260                };
261                return false;
262            }
263        }
264
265        true
266    });
267
268    if let Some(error) = error {
269        return Err(error);
270    }
271    Ok(ret)
272}
273
274/// Removes all attributes with `zerovec` in the name and places them in a separate vector
275pub fn extract_zerovec_attributes(attrs: &mut Vec<Attribute>) -> Vec<Attribute> {
276    let mut ret = vec![];
277    attrs.retain(|a| {
278        if a.path().segments.len() == 2 && a.path().segments[0].ident == "zerovec" {
279            ret.push(a.clone());
280            return false;
281        }
282        true
283    });
284    ret
285}
286
287/// Extract attributes from field, and return them
288///
289/// Only current field attribute is `zerovec::varule(VarUleType)`
290pub fn extract_field_attributes(attrs: &mut Vec<Attribute>) -> Result<Option<Ident>> {
291    let mut zerovec_attrs = extract_zerovec_attributes(attrs);
292    let varule = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "varule")?;
293
294    if varule.len() > 1 {
295        return Err(Error::new(
296            varule[1].span(),
297            "Found multiple #[zerovec::varule()] on one field",
298        ));
299    }
300
301    if !zerovec_attrs.is_empty() {
302        return Err(Error::new(
303            zerovec_attrs[1].span(),
304            "Found unusable #[zerovec::] attrs on field, only #[zerovec::varule()] supported",
305        ));
306    }
307
308    Ok(varule.first().cloned())
309}
310
311#[derive(Default, Clone)]
312pub struct ZeroVecAttrs {
313    pub skip_kv: bool,
314    pub skip_ord: bool,
315    pub skip_toowned: bool,
316    pub skip_from: bool,
317    pub serialize: bool,
318    pub deserialize: bool,
319    pub debug: bool,
320    pub hash: bool,
321    pub vzv_format: Option<TokenStream2>,
322}
323
324/// Removes all known zerovec:: attributes from struct attrs and validates them
325pub fn extract_attributes_common(
326    attrs: &mut Vec<Attribute>,
327    span: Span,
328    is_var: bool,
329) -> Result<ZeroVecAttrs> {
330    let mut zerovec_attrs = extract_zerovec_attributes(attrs);
331
332    let derive = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "derive")?;
333    let skip = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "skip_derive")?;
334    let format = extract_single_tt_attr(&mut zerovec_attrs, "format")?;
335
336    let name = if is_var { "make_varule" } else { "make_ule" };
337
338    if let Some(attr) = zerovec_attrs.first() {
339        return Err(Error::new(
340            attr.span(),
341            format!("Found unknown or duplicate attribute for #[{name}]"),
342        ));
343    }
344
345    let mut attrs = ZeroVecAttrs::default();
346
347    for ident in derive {
348        if ident == "Serialize" {
349            attrs.serialize = true;
350        } else if ident == "Deserialize" {
351            attrs.deserialize = true;
352        } else if ident == "Debug" {
353            attrs.debug = true;
354        } else if ident == "Hash" {
355            attrs.hash = true;
356        } else {
357            return Err(Error::new(
358                ident.span(),
359                format!(
360                    "Found unknown derive attribute for #[{name}]: #[zerovec::derive({ident})]"
361                ),
362            ));
363        }
364    }
365
366    for ident in skip {
367        if ident == "ZeroMapKV" {
368            attrs.skip_kv = true;
369        } else if ident == "Ord" {
370            attrs.skip_ord = true;
371        } else if ident == "ToOwned" && is_var {
372            attrs.skip_toowned = true;
373        } else if ident == "From" && is_var {
374            attrs.skip_from = true;
375        } else {
376            return Err(Error::new(
377                ident.span(),
378                format!("Found unknown derive attribute for #[{name}]: #[zerovec::skip_derive({ident})]"),
379            ));
380        }
381    }
382
383    if let Some(ref format) = format {
384        if !is_var {
385            return Err(Error::new(
386                format.span(),
387                format!(
388                    "Found unknown derive attribute for #[{name}]: #[zerovec::format({format})]"
389                ),
390            ));
391        }
392    }
393    attrs.vzv_format = format;
394
395    if (attrs.serialize || attrs.deserialize) && !is_var {
396        return Err(Error::new(
397            span,
398            "#[make_ule] does not support #[zerovec::derive(Serialize, Deserialize)]",
399        ));
400    }
401
402    Ok(attrs)
403}