palette_derive/cast/
array_cast.rs
1use proc_macro::TokenStream;
2use proc_macro2::Span;
3
4use quote::{quote, ToTokens};
5use syn::{
6 punctuated::Punctuated, token::Comma, Attribute, Data, DeriveInput, Fields, Meta, Path, Type,
7};
8
9use crate::meta::{self, FieldAttributes, IdentOrIndex, TypeItemAttributes};
10use crate::util;
11
12pub fn derive(tokens: TokenStream) -> std::result::Result<TokenStream, Vec<syn::Error>> {
13 let DeriveInput {
14 ident,
15 attrs,
16 generics,
17 data,
18 ..
19 } = syn::parse(tokens).map_err(|error| vec![error])?;
20
21 let allowed_repr = is_allowed_repr(&attrs)?;
22 let (item_meta, item_errors) = meta::parse_namespaced_attributes::<TypeItemAttributes>(attrs);
23
24 let mut number_of_channels = 0usize;
25 let mut field_type: Option<Type> = None;
26
27 let (all_fields, fields_meta, field_errors) = match data {
28 Data::Struct(struct_item) => {
29 let (fields_meta, field_errors) =
30 meta::parse_field_attributes::<FieldAttributes>(struct_item.fields.clone());
31 let all_fields = match struct_item.fields {
32 Fields::Named(fields) => fields.named,
33 Fields::Unnamed(fields) => fields.unnamed,
34 Fields::Unit => Default::default(),
35 };
36
37 (all_fields, fields_meta, field_errors)
38 }
39 Data::Enum(_) => {
40 return Err(vec![syn::Error::new(
41 Span::call_site(),
42 "`ArrayCast` cannot be derived for enums, because of the discriminant",
43 )]);
44 }
45 Data::Union(_) => {
46 return Err(vec![syn::Error::new(
47 Span::call_site(),
48 "`ArrayCast` cannot be derived for unions",
49 )]);
50 }
51 };
52
53 let fields = all_fields
54 .into_iter()
55 .enumerate()
56 .map(|(index, field)| {
57 (
58 field
59 .ident
60 .map(IdentOrIndex::Ident)
61 .unwrap_or_else(|| IdentOrIndex::Index(index.into())),
62 field.ty,
63 )
64 })
65 .filter(|(field, _)| !fields_meta.zero_size_fields.contains(field));
66
67 let mut errors = Vec::new();
68
69 for (field, ty) in fields {
70 let ty = fields_meta
71 .type_substitutes
72 .get(&field)
73 .cloned()
74 .unwrap_or(ty);
75 number_of_channels += 1;
76
77 if let Some(field_type) = field_type.clone() {
78 if field_type != ty {
79 errors.push(syn::Error::new_spanned(
80 &field,
81 format!(
82 "expected fields to have type `{}`",
83 field_type.into_token_stream()
84 ),
85 ));
86 }
87 } else {
88 field_type = Some(ty);
89 }
90 }
91
92 if !allowed_repr {
93 errors.push(syn::Error::new(
94 Span::call_site(),
95 format!(
96 "a `#[repr(C)]` or `#[repr(transparent)]` attribute is required to give `{}` a fixed memory layout",
97 ident
98 ),
99 ));
100 }
101
102 let array_cast_trait_path = util::path(["cast", "ArrayCast"], item_meta.internal);
103
104 let mut implementation = if let Some(field_type) = field_type {
105 let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
106
107 quote! {
108 #[automatically_derived]
109 unsafe impl #impl_generics #array_cast_trait_path for #ident #type_generics #where_clause {
110 type Array = [#field_type; #number_of_channels];
111 }
112 }
113 } else {
114 errors.push(syn::Error::new(
115 Span::call_site(),
116 "`ArrayCast` can only be derived for structs with one or more fields".to_string(),
117 ));
118
119 return Err(errors);
120 };
121
122 implementation.extend(errors.iter().map(syn::Error::to_compile_error));
123
124 let item_errors = item_errors
125 .into_iter()
126 .map(|error| error.into_compile_error());
127 let field_errors = field_errors
128 .into_iter()
129 .map(|error| error.into_compile_error());
130
131 Ok(quote! {
132 #(#item_errors)*
133 #(#field_errors)*
134
135 #implementation
136 }
137 .into())
138}
139
140fn is_allowed_repr(attributes: &[Attribute]) -> std::result::Result<bool, Vec<syn::Error>> {
141 let mut errors = Vec::new();
142
143 for attribute in attributes {
144 let attribute_name = attribute.path().get_ident().map(ToString::to_string);
145
146 if let Some("repr") = attribute_name.as_deref() {
147 let meta_list = match attribute.meta.require_list() {
148 Ok(list) => list,
149 Err(error) => {
150 errors.push(error);
151 continue;
152 }
153 };
154
155 let items = match meta_list.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)
156 {
157 Ok(items) => items,
158 Err(error) => {
159 errors.push(error);
160 continue;
161 }
162 };
163
164 let contains_allowed_repr = items.iter().any(|item| {
165 item.require_path_only()
166 .ok()
167 .and_then(Path::get_ident)
168 .map_or(false, |ident| ident == "C" || ident == "transparent")
169 });
170
171 if contains_allowed_repr {
172 return Ok(true);
173 }
174 }
175 }
176
177 if errors.is_empty() {
178 Ok(false)
179 } else {
180 Err(errors)
181 }
182}