ouroboros_macro/
parse.rs

1use proc_macro2::{Span, TokenTree};
2use quote::format_ident;
3use syn::{
4    spanned::Spanned, Attribute, Error, Fields, GenericParam, ItemStruct, MacroDelimiter, Meta,
5};
6
7use crate::{
8    covariance_detection::type_is_covariant_over_this_lifetime,
9    info_structures::{BorrowRequest, Derive, FieldType, StructFieldInfo, StructInfo},
10    utils::submodule_contents_visibility,
11};
12
13fn handle_borrows_attr(
14    field_info: &mut [StructFieldInfo],
15    attr: &Attribute,
16    borrows: &mut Vec<BorrowRequest>,
17) -> Result<(), Error> {
18    let mut borrow_mut = false;
19    let mut waiting_for_comma = false;
20    let tokens = match &attr.meta {
21        Meta::List(ml) => ml.tokens.clone(),
22        _ => {
23            return Err(Error::new_spanned(
24                &attr.meta,
25                "Invalid syntax for borrows() macro.",
26            ))
27        }
28    };
29    for token in tokens {
30        if let TokenTree::Ident(ident) = token {
31            if waiting_for_comma {
32                return Err(Error::new_spanned(&ident, "Expected comma."));
33            }
34            let istr = ident.to_string();
35            if istr == "mut" {
36                if borrow_mut {
37                    return Err(Error::new_spanned(&ident, "Unexpected double 'mut'"));
38                }
39                borrow_mut = true;
40            } else {
41                let index = field_info.iter().position(|item| item.name == istr);
42                let index = if let Some(v) = index {
43                    v
44                } else {
45                    return Err(Error::new_spanned(
46                        &ident,
47                        concat!(
48                            "Unknown identifier, make sure that it is spelled ",
49                            "correctly and defined above the location it is borrowed."
50                        ),
51                    ));
52                };
53                if borrow_mut {
54                    if field_info[index].field_type == FieldType::Borrowed {
55                        return Err(Error::new_spanned(
56                            &ident,
57                            "Cannot borrow mutably, this field was previously borrowed immutably.",
58                        ));
59                    }
60                    if field_info[index].field_type == FieldType::BorrowedMut {
61                        return Err(Error::new_spanned(&ident, "Cannot borrow mutably twice."));
62                    }
63                    field_info[index].field_type = FieldType::BorrowedMut;
64                } else {
65                    if field_info[index].field_type == FieldType::BorrowedMut {
66                        return Err(Error::new_spanned(
67                            &ident,
68                            "Cannot borrow as immutable as it was previously borrowed mutably.",
69                        ));
70                    }
71                    field_info[index].field_type = FieldType::Borrowed;
72                }
73                borrows.push(BorrowRequest {
74                    index,
75                    mutable: borrow_mut,
76                });
77                waiting_for_comma = true;
78                borrow_mut = false;
79            }
80        } else if let TokenTree::Punct(punct) = token {
81            if punct.as_char() == ',' {
82                if waiting_for_comma {
83                    waiting_for_comma = false;
84                } else {
85                    return Err(Error::new_spanned(&punct, "Unexpected extra comma."));
86                }
87            } else {
88                return Err(Error::new_spanned(
89                    &punct,
90                    "Unexpected punctuation, expected comma or identifier.",
91                ));
92            }
93        } else {
94            return Err(Error::new_spanned(
95                &token,
96                "Unexpected token, expected comma or identifier.",
97            ));
98        }
99    }
100    Ok(())
101}
102
103fn parse_derive_token(token: &TokenTree) -> Result<Option<Derive>, Error> {
104    match token {
105        TokenTree::Ident(ident) => match &ident.to_string()[..] {
106            "Debug" => Ok(Some(Derive::Debug)),
107            "PartialEq" => Ok(Some(Derive::PartialEq)),
108            "Eq" => Ok(Some(Derive::Eq)),
109            _ => Err(Error::new(
110                ident.span(),
111                format!("{} cannot be derived for self-referencing structs", ident),
112            )),
113        },
114        TokenTree::Punct(..) => Ok(None),
115        _ => Err(Error::new(token.span(), "bad syntax")),
116    }
117}
118
119fn parse_derive_attribute(attr: &Attribute) -> Result<Vec<Derive>, Error> {
120    let body = match &attr.meta {
121        Meta::List(ml) => ml,
122        _ => unreachable!(),
123    };
124    if !matches!(body.delimiter, MacroDelimiter::Paren(_)) {
125        return Err(Error::new(
126            attr.span(),
127            format!(
128                "malformed derive input, derive attributes are of the form `#[derive({})]`",
129                body.tokens
130            ),
131        ));
132    }
133    let mut derives = Vec::new();
134    for token in body.tokens.clone().into_iter() {
135        if let Some(derive) = parse_derive_token(&token)? {
136            derives.push(derive);
137        }
138    }
139    Ok(derives)
140}
141
142pub fn parse_struct(def: &ItemStruct) -> Result<StructInfo, Error> {
143    let vis = def.vis.clone();
144    let generics = def.generics.clone();
145    let mut actual_struct_def = def.clone();
146    actual_struct_def.vis = vis.clone();
147    let mut fields = Vec::new();
148    match &mut actual_struct_def.fields {
149        Fields::Named(def_fields) => {
150            for field in &mut def_fields.named {
151                let mut borrows = Vec::new();
152                let mut self_referencing = false;
153                let mut covariant = type_is_covariant_over_this_lifetime(&field.ty);
154                let mut remove_attrs = Vec::new();
155                for (index, attr) in field.attrs.iter().enumerate() {
156                    let path = &attr.path();
157                    if path.leading_colon.is_some() {
158                        continue;
159                    }
160                    if path.segments.len() != 1 {
161                        continue;
162                    }
163                    if path.segments.first().unwrap().ident == "borrows" {
164                        if self_referencing {
165                            panic!("TODO: Nice error, used #[borrows()] twice.");
166                        }
167                        self_referencing = true;
168                        handle_borrows_attr(&mut fields[..], attr, &mut borrows)?;
169                        remove_attrs.push(index);
170                    }
171                    if path.segments.first().unwrap().ident == "covariant" {
172                        if covariant.is_some() {
173                            panic!("TODO: Nice error, covariance specified twice.");
174                        }
175                        covariant = Some(true);
176                        remove_attrs.push(index);
177                    }
178                    if path.segments.first().unwrap().ident == "not_covariant" {
179                        if covariant.is_some() {
180                            panic!("TODO: Nice error, covariance specified twice.");
181                        }
182                        covariant = Some(false);
183                        remove_attrs.push(index);
184                    }
185                }
186                // We should not be able to access the field outside of the hidden module where
187                // everything is generated.
188                let with_vis = submodule_contents_visibility(&field.vis.clone());
189                fields.push(StructFieldInfo {
190                    name: field.ident.clone().expect("Named field has no name."),
191                    typ: field.ty.clone(),
192                    field_type: FieldType::Tail,
193                    vis: with_vis,
194                    borrows,
195                    self_referencing,
196                    covariant,
197                });
198            }
199        }
200        Fields::Unnamed(_fields) => {
201            return Err(Error::new(
202                Span::call_site(),
203                "Tuple structs are not supported yet.",
204            ))
205        }
206        Fields::Unit => {
207            return Err(Error::new(
208                Span::call_site(),
209                "Unit structs cannot be self-referential.",
210            ))
211        }
212    }
213    if fields.len() < 2 {
214        return Err(Error::new(
215            Span::call_site(),
216            "Self-referencing structs must have at least 2 fields.",
217        ));
218    }
219    let mut has_non_tail = false;
220    for field in &fields {
221        if !field.field_type.is_tail() {
222            has_non_tail = true;
223            break;
224        }
225    }
226    if !has_non_tail {
227        return Err(Error::new(
228            Span::call_site(),
229            format!(
230                concat!(
231                    "Self-referencing struct cannot be made entirely of tail fields, try adding ",
232                    "#[borrows({0})] to a field defined after {0}."
233                ),
234                fields[0].name
235            ),
236        ));
237    }
238    let first_lifetime = if let Some(GenericParam::Lifetime(param)) = generics.params.first() {
239        param.lifetime.ident.clone()
240    } else {
241        format_ident!("static")
242    };
243    let mut attributes = Vec::new();
244    let mut derives = Vec::new();
245    for attr in &def.attrs {
246        let p = &attr.path().segments;
247        if p.is_empty() {
248            return Err(Error::new(p.span(), "Unsupported attribute".to_string()));
249        }
250        let name = p[0].ident.to_string();
251        let good = matches!(&name[..], "clippy" | "allow" | "deny" | "doc");
252        if good {
253            attributes.push(attr.clone())
254        } else if name == "derive" {
255            if !derives.is_empty() {
256                return Err(Error::new(
257                    attr.span(),
258                    "Multiple derive attributes not allowed",
259                ));
260            } else {
261                derives = parse_derive_attribute(attr)?;
262            }
263        } else {
264            return Err(Error::new(p.span(), "Unsupported attribute".to_string()));
265        }
266    }
267
268    Ok(StructInfo {
269        derives,
270        ident: def.ident.clone(),
271        internal_ident: format_ident!("{}Internal", def.ident),
272        generics: def.generics.clone(),
273        fields,
274        vis,
275        first_lifetime,
276        attributes,
277    })
278}