palette_derive/cast/array_cast.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3
4use quote::{quote, ToTokens};
5use syn::{
6    punctuated::Punctuated, token::Comma, Attribute, Data, DeriveInput, Fields, Meta, Path, Type,
7};
8
9use crate::meta::{self, FieldAttributes, IdentOrIndex, TypeItemAttributes};
10use crate::util;
11
12pub fn derive(tokens: TokenStream) -> std::result::Result<TokenStream, Vec<syn::Error>> {
13    let DeriveInput {
14        ident,
15        attrs,
16        generics,
17        data,
18        ..
19    } = syn::parse(tokens).map_err(|error| vec![error])?;
20
21    let allowed_repr = is_allowed_repr(&attrs)?;
22    let (item_meta, item_errors) = meta::parse_namespaced_attributes::<TypeItemAttributes>(attrs);
23
24    let mut number_of_channels = 0usize;
25    let mut field_type: Option<Type> = None;
26
27    let (all_fields, fields_meta, field_errors) = match data {
28        Data::Struct(struct_item) => {
29            let (fields_meta, field_errors) =
30                meta::parse_field_attributes::<FieldAttributes>(struct_item.fields.clone());
31            let all_fields = match struct_item.fields {
32                Fields::Named(fields) => fields.named,
33                Fields::Unnamed(fields) => fields.unnamed,
34                Fields::Unit => Default::default(),
35            };
36
37            (all_fields, fields_meta, field_errors)
38        }
39        Data::Enum(_) => {
40            return Err(vec![syn::Error::new(
41                Span::call_site(),
42                "`ArrayCast` cannot be derived for enums, because of the discriminant",
43            )]);
44        }
45        Data::Union(_) => {
46            return Err(vec![syn::Error::new(
47                Span::call_site(),
48                "`ArrayCast` cannot be derived for unions",
49            )]);
50        }
51    };
52
53    let fields = all_fields
54        .into_iter()
55        .enumerate()
56        .map(|(index, field)| {
57            (
58                field
59                    .ident
60                    .map(IdentOrIndex::Ident)
61                    .unwrap_or_else(|| IdentOrIndex::Index(index.into())),
62                field.ty,
63            )
64        })
65        .filter(|(field, _)| !fields_meta.zero_size_fields.contains(field));
66
67    let mut errors = Vec::new();
68
69    for (field, ty) in fields {
70        let ty = fields_meta
71            .type_substitutes
72            .get(&field)
73            .cloned()
74            .unwrap_or(ty);
75        number_of_channels += 1;
76
77        if let Some(field_type) = field_type.clone() {
78            if field_type != ty {
79                errors.push(syn::Error::new_spanned(
80                    &field,
81                    format!(
82                        "expected fields to have type `{}`",
83                        field_type.into_token_stream()
84                    ),
85                ));
86            }
87        } else {
88            field_type = Some(ty);
89        }
90    }
91
92    if !allowed_repr {
93        errors.push(syn::Error::new(
94            Span::call_site(),
95            format!(
96                "a `#[repr(C)]` or `#[repr(transparent)]` attribute is required to give `{}` a fixed memory layout",
97                ident
98            ),
99        ));
100    }
101
102    let array_cast_trait_path = util::path(["cast", "ArrayCast"], item_meta.internal);
103
104    let mut implementation = if let Some(field_type) = field_type {
105        let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
106
107        quote! {
108            #[automatically_derived]
109            unsafe impl #impl_generics #array_cast_trait_path for #ident #type_generics #where_clause {
110                type Array = [#field_type; #number_of_channels];
111            }
112        }
113    } else {
114        errors.push(syn::Error::new(
115            Span::call_site(),
116            "`ArrayCast` can only be derived for structs with one or more fields".to_string(),
117        ));
118
119        return Err(errors);
120    };
121
122    implementation.extend(errors.iter().map(syn::Error::to_compile_error));
123
124    let item_errors = item_errors
125        .into_iter()
126        .map(|error| error.into_compile_error());
127    let field_errors = field_errors
128        .into_iter()
129        .map(|error| error.into_compile_error());
130
131    Ok(quote! {
132        #(#item_errors)*
133        #(#field_errors)*
134
135        #implementation
136    }
137    .into())
138}
139
140fn is_allowed_repr(attributes: &[Attribute]) -> std::result::Result<bool, Vec<syn::Error>> {
141    let mut errors = Vec::new();
142
143    for attribute in attributes {
144        let attribute_name = attribute.path().get_ident().map(ToString::to_string);
145
146        if let Some("repr") = attribute_name.as_deref() {
147            let meta_list = match attribute.meta.require_list() {
148                Ok(list) => list,
149                Err(error) => {
150                    errors.push(error);
151                    continue;
152                }
153            };
154
155            let items = match meta_list.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)
156            {
157                Ok(items) => items,
158                Err(error) => {
159                    errors.push(error);
160                    continue;
161                }
162            };
163
164            let contains_allowed_repr = items.iter().any(|item| {
165                item.require_path_only()
166                    .ok()
167                    .and_then(Path::get_ident)
168                    .map_or(false, |ident| ident == "C" || ident == "transparent")
169            });
170
171            if contains_allowed_repr {
172                return Ok(true);
173            }
174        }
175    }
176
177    if errors.is_empty() {
178        Ok(false)
179    } else {
180        Err(errors)
181    }
182}