wayland_scanner/
common.rs

1use std::fmt::Write;
2
3use proc_macro2::{Ident, Literal, Span, TokenStream};
4
5use quote::{format_ident, quote, ToTokens};
6
7use crate::{protocol::*, util::*, Side};
8
9pub(crate) fn generate_enums_for(interface: &Interface) -> TokenStream {
10    interface.enums.iter().map(ToTokens::into_token_stream).collect()
11}
12
13impl ToTokens for Enum {
14    fn to_tokens(&self, tokens: &mut TokenStream) {
15        let enum_decl;
16        let enum_impl;
17
18        let doc_attr = self.description.as_ref().map(description_to_doc_attr);
19        let ident = Ident::new(&snake_to_camel(&self.name), Span::call_site());
20
21        if self.bitfield {
22            let entries = self.entries.iter().map(|entry| {
23                let doc_attr = entry
24                    .description
25                    .as_ref()
26                    .map(description_to_doc_attr)
27                    .or_else(|| entry.summary.as_ref().map(|s| to_doc_attr(s)));
28
29                let prefix = if entry.name.chars().next().unwrap().is_numeric() { "_" } else { "" };
30                let ident = format_ident!("{}{}", prefix, snake_to_camel(&entry.name));
31
32                let value = Literal::u32_unsuffixed(entry.value);
33
34                quote! {
35                    #doc_attr
36                    const #ident = #value;
37                }
38            });
39
40            enum_decl = quote! {
41                bitflags::bitflags! {
42                    #doc_attr
43                    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
44                    pub struct #ident: u32 {
45                        #(#entries)*
46                    }
47                }
48            };
49            enum_impl = quote! {
50                impl std::convert::TryFrom<u32> for #ident {
51                    type Error = ();
52                    fn try_from(val: u32) -> Result<#ident, ()> {
53                        #ident::from_bits(val).ok_or(())
54                    }
55                }
56                impl std::convert::From<#ident> for u32 {
57                    fn from(val: #ident) -> u32 {
58                        val.bits()
59                    }
60                }
61            };
62        } else {
63            let variants = self.entries.iter().map(|entry| {
64                let doc_attr = entry
65                    .description
66                    .as_ref()
67                    .map(description_to_doc_attr)
68                    .or_else(|| entry.summary.as_ref().map(|s| to_doc_attr(s)));
69
70                let prefix = if entry.name.chars().next().unwrap().is_numeric() { "_" } else { "" };
71                let variant = format_ident!("{}{}", prefix, snake_to_camel(&entry.name));
72
73                let value = Literal::u32_unsuffixed(entry.value);
74
75                quote! {
76                    #doc_attr
77                    #variant = #value
78                }
79            });
80
81            enum_decl = quote! {
82                #doc_attr
83                #[repr(u32)]
84                #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
85                #[non_exhaustive]
86                pub enum #ident {
87                    #(#variants,)*
88                }
89            };
90
91            let match_arms = self.entries.iter().map(|entry| {
92                let value = Literal::u32_unsuffixed(entry.value);
93
94                let prefix = if entry.name.chars().next().unwrap().is_numeric() { "_" } else { "" };
95                let variant = format_ident!("{}{}", prefix, snake_to_camel(&entry.name));
96
97                quote! {
98                    #value => Ok(#ident::#variant)
99                }
100            });
101
102            enum_impl = quote! {
103                impl std::convert::TryFrom<u32> for #ident {
104                    type Error = ();
105                    fn try_from(val: u32) -> Result<#ident, ()> {
106                        match val {
107                            #(#match_arms,)*
108                            _ => Err(())
109                        }
110                    }
111                }
112                impl std::convert::From<#ident> for u32 {
113                    fn from(val: #ident) -> u32 {
114                        val as u32
115                    }
116                }
117            };
118        }
119
120        enum_decl.to_tokens(tokens);
121        enum_impl.to_tokens(tokens);
122    }
123}
124
125pub(crate) fn gen_msg_constants(requests: &[Message], events: &[Message]) -> TokenStream {
126    let req_constants = requests.iter().enumerate().map(|(opcode, msg)| {
127        let since_cstname = format_ident!("REQ_{}_SINCE", msg.name.to_ascii_uppercase());
128        let opcode_cstname = format_ident!("REQ_{}_OPCODE", msg.name.to_ascii_uppercase());
129        let since = msg.since;
130        let opcode = opcode as u16;
131        quote! {
132            /// The minimal object version supporting this request
133            pub const #since_cstname: u32 = #since;
134            /// The wire opcode for this request
135            pub const #opcode_cstname: u16 = #opcode;
136        }
137    });
138    let evt_constants = events.iter().enumerate().map(|(opcode, msg)| {
139        let since_cstname = format_ident!("EVT_{}_SINCE", msg.name.to_ascii_uppercase());
140        let opcode_cstname = format_ident!("EVT_{}_OPCODE", msg.name.to_ascii_uppercase());
141        let since = msg.since;
142        let opcode = opcode as u16;
143        quote! {
144            /// The minimal object version supporting this event
145            pub const #since_cstname: u32 = #since;
146            /// The wire opcode for this event
147            pub const #opcode_cstname: u16 = #opcode;
148        }
149    });
150
151    quote! {
152        #(#req_constants)*
153        #(#evt_constants)*
154    }
155}
156
157pub(crate) fn gen_message_enum(
158    name: &Ident,
159    side: Side,
160    receiver: bool,
161    messages: &[Message],
162) -> TokenStream {
163    let variants = messages
164        .iter()
165        .map(|msg| {
166            let mut docs = String::new();
167            if let Some((ref short, ref long)) = msg.description {
168                write!(docs, "{}\n\n{}\n", short, long.trim()).unwrap();
169            }
170            if let Some(Type::Destructor) = msg.typ {
171                write!(
172                    docs,
173                    "\nThis is a destructor, once {} this object cannot be used any longer.",
174                    if receiver { "received" } else { "sent" },
175                )
176                .unwrap()
177            }
178            if msg.since > 1 {
179                write!(docs, "\nOnly available since version {} of the interface", msg.since)
180                    .unwrap();
181            }
182
183            let doc_attr = to_doc_attr(&docs);
184            let msg_name = Ident::new(&snake_to_camel(&msg.name), Span::call_site());
185            let msg_variant_decl =
186                if msg.args.is_empty() {
187                    msg_name.into_token_stream()
188                } else {
189                    let fields = msg.args.iter().flat_map(|arg| {
190                let field_name =
191                    format_ident!("{}{}", if is_keyword(&arg.name) { "_" } else { "" }, arg.name);
192                let field_type_inner = if let Some(ref enu) = arg.enum_ {
193                    let enum_type = dotted_to_relname(enu);
194                    quote! { WEnum<#enum_type> }
195                } else {
196                    match arg.typ {
197                        Type::Uint => quote! { u32 },
198                        Type::Int => quote! { i32 },
199                        Type::Fixed => quote! { f64 },
200                        Type::String => quote! { String },
201                        Type::Array => quote! { Vec<u8> },
202                        Type::Fd => {
203                            if receiver {
204                                quote! { OwnedFd }
205                            } else {
206                                quote! { std::os::unix::io::BorrowedFd<'a> }
207                            }
208                        }
209                        Type::Object => {
210                            if let Some(ref iface) = arg.interface {
211                                let iface_mod = Ident::new(iface, Span::call_site());
212                                let iface_type =
213                                    Ident::new(&snake_to_camel(iface), Span::call_site());
214                                quote! { super::#iface_mod::#iface_type }
215                            } else if side == Side::Client {
216                                quote! { super::wayland_client::ObjectId }
217                            } else {
218                                quote! { super::wayland_server::ObjectId }
219                            }
220                        }
221                        Type::NewId if !receiver && side == Side::Client => {
222                            // Client-side sending does not have a pre-existing object
223                            // so skip serializing it
224                            if arg.interface.is_some() {
225                                return None;
226                            } else {
227                                quote! { (&'static Interface, u32) }
228                            }
229                        }
230                        Type::NewId => {
231                            if let Some(ref iface) = arg.interface {
232                                let iface_mod = Ident::new(iface, Span::call_site());
233                                let iface_type =
234                                    Ident::new(&snake_to_camel(iface), Span::call_site());
235                                if receiver && side == Side::Server {
236                                    quote! { New<super::#iface_mod::#iface_type> }
237                                } else {
238                                    quote! { super::#iface_mod::#iface_type }
239                                }
240                            } else {
241                                // bind-like function
242                                if side == Side::Client {
243                                    quote! { (String, u32, super::wayland_client::ObjectId) }
244                                } else {
245                                    quote! { (String, u32, super::wayland_server::ObjectId) }
246                                }
247                            }
248                        }
249                        Type::Destructor => panic!("An argument cannot have type \"destructor\"."),
250                    }
251                };
252
253                let field_type = if arg.allow_null {
254                    quote! { Option<#field_type_inner> }
255                } else {
256                    field_type_inner.into_token_stream()
257                };
258
259                let doc_attr = arg
260                    .description
261                    .as_ref()
262                    .map(description_to_doc_attr)
263                    .or_else(|| arg.summary.as_ref().map(|s| to_doc_attr(s)));
264
265                Some(quote! {
266                    #doc_attr
267                    #field_name: #field_type
268                })
269            });
270
271                    quote! {
272                        #msg_name {
273                            #(#fields,)*
274                        }
275                    }
276                };
277
278            quote! {
279                #doc_attr
280                #msg_variant_decl
281            }
282        })
283        .collect::<Vec<_>>();
284
285    let opcodes = messages.iter().enumerate().map(|(opcode, msg)| {
286        let msg_name = Ident::new(&snake_to_camel(&msg.name), Span::call_site());
287        let opcode = opcode as u16;
288        if msg.args.is_empty() {
289            quote! {
290                #name::#msg_name => #opcode
291            }
292        } else {
293            quote! {
294                #name::#msg_name { .. } => #opcode
295            }
296        }
297    });
298
299    // Placeholder to allow generic argument to be added later, without ABI
300    // break.
301    // TODO Use never type.
302    let (generic, phantom_variant, phantom_case) = if !receiver {
303        (
304            quote! { 'a },
305            quote! { #[doc(hidden)] __phantom_lifetime { phantom: std::marker::PhantomData<&'a ()>, never: std::convert::Infallible } },
306            quote! { #name::__phantom_lifetime { never, .. } => match never {} },
307        )
308    } else {
309        (quote! {}, quote! {}, quote! {})
310    };
311
312    quote! {
313        #[derive(Debug)]
314        #[non_exhaustive]
315        pub enum #name<#generic> {
316            #(#variants,)*
317            #phantom_variant
318        }
319
320        impl<#generic> #name<#generic> {
321            #[doc="Get the opcode number of this message"]
322            pub fn opcode(&self) -> u16 {
323                match *self {
324                    #(#opcodes,)*
325                    #phantom_case
326                }
327            }
328        }
329    }
330}
331
332pub(crate) fn gen_parse_body(interface: &Interface, side: Side) -> TokenStream {
333    let msgs = match side {
334        Side::Client => &interface.events,
335        Side::Server => &interface.requests,
336    };
337    let object_type = Ident::new(
338        match side {
339            Side::Client => "Proxy",
340            Side::Server => "Resource",
341        },
342        Span::call_site(),
343    );
344    let msg_type = Ident::new(
345        match side {
346            Side::Client => "Event",
347            Side::Server => "Request",
348        },
349        Span::call_site(),
350    );
351
352    let match_arms = msgs.iter().enumerate().map(|(opcode, msg)| {
353        let opcode = opcode as u16;
354        let msg_name = Ident::new(&snake_to_camel(&msg.name), Span::call_site());
355        let args_pat = msg.args.iter().map(|arg| {
356            let arg_name = Ident::new(
357                &format!("{}{}", if is_keyword(&arg.name) { "_" } else { "" }, arg.name),
358                Span::call_site(),
359            );
360            match arg.typ {
361                Type::Uint => quote!{ Some(Argument::Uint(#arg_name)) },
362                Type::Int => quote!{ Some(Argument::Int(#arg_name)) },
363                Type::String => quote!{ Some(Argument::Str(#arg_name)) },
364                Type::Fixed => quote!{ Some(Argument::Fixed(#arg_name)) },
365                Type::Array => quote!{ Some(Argument::Array(#arg_name)) },
366                Type::Object => quote!{ Some(Argument::Object(#arg_name)) },
367                Type::NewId => quote!{ Some(Argument::NewId(#arg_name)) },
368                Type::Fd => quote!{ Some(Argument::Fd(#arg_name)) },
369                Type::Destructor => panic!("Argument {}.{}.{} has type destructor ?!", interface.name, msg.name, arg.name),
370            }
371        });
372
373        let args_iter = msg.args.iter().map(|_| quote!{ arg_iter.next() });
374
375        let arg_names = msg.args.iter().map(|arg| {
376            let arg_name = format_ident!("{}{}", if is_keyword(&arg.name) { "_" } else { "" }, arg.name);
377            if arg.enum_.is_some() {
378                quote! { #arg_name: From::from(#arg_name as u32) }
379            } else {
380                match arg.typ {
381                    Type::Uint | Type::Int | Type::Fd => quote!{ #arg_name },
382                    Type::Fixed => quote!{ #arg_name: (#arg_name as f64) / 256.},
383                    Type::String => {
384                        if arg.allow_null {
385                            quote! {
386                                #arg_name: #arg_name.as_ref().map(|s| String::from_utf8_lossy(s.as_bytes()).into_owned())
387                            }
388                        } else {
389                            quote! {
390                                #arg_name: String::from_utf8_lossy(#arg_name.as_ref().unwrap().as_bytes()).into_owned()
391                            }
392                        }
393                    },
394                    Type::Object => {
395                        let create_proxy = if let Some(ref created_interface) = arg.interface {
396                            let created_iface_mod = Ident::new(created_interface, Span::call_site());
397                            let created_iface_type = Ident::new(&snake_to_camel(created_interface), Span::call_site());
398                            quote! {
399                                match <super::#created_iface_mod::#created_iface_type as #object_type>::from_id(conn, #arg_name.clone()) {
400                                    Ok(p) => p,
401                                    Err(_) => return Err(DispatchError::BadMessage {
402                                        sender_id: msg.sender_id,
403                                        interface: Self::interface().name,
404                                        opcode: msg.opcode
405                                    }),
406                                }
407                            }
408                        } else {
409                            quote! { #arg_name.clone() }
410                        };
411                        if arg.allow_null {
412                            quote! {
413                                #arg_name: if #arg_name.is_null() { None } else { Some(#create_proxy) }
414                            }
415                        } else {
416                            quote! {
417                                #arg_name: #create_proxy
418                            }
419                        }
420                    },
421                    Type::NewId => {
422                        let create_proxy = if let Some(ref created_interface) = arg.interface {
423                            let created_iface_mod = Ident::new(created_interface, Span::call_site());
424                            let created_iface_type = Ident::new(&snake_to_camel(created_interface), Span::call_site());
425                            quote! {
426                                match <super::#created_iface_mod::#created_iface_type as #object_type>::from_id(conn, #arg_name.clone()) {
427                                    Ok(p) => p,
428                                    Err(_) => return Err(DispatchError::BadMessage {
429                                        sender_id: msg.sender_id,
430                                        interface: Self::interface().name,
431                                        opcode: msg.opcode,
432                                    }),
433                                }
434                            }
435                        } else if side == Side::Server {
436                            quote! { New::wrap(#arg_name.clone()) }
437                        } else {
438                            quote! { #arg_name.clone() }
439                        };
440                        if arg.allow_null {
441                            if side == Side::Server {
442                                quote! {
443                                    #arg_name: if #arg_name.is_null() { None } else { Some(New::wrap(#create_proxy)) }
444                                }
445                            } else {
446                                quote! {
447                                    #arg_name: if #arg_name.is_null() { None } else { Some(#create_proxy) }
448                                }
449                            }
450                        } else if side == Side::Server {
451                            quote! {
452                                #arg_name: New::wrap(#create_proxy)
453                            }
454                        } else  {
455                            quote! {
456                                #arg_name: #create_proxy
457                            }
458                        }
459                    },
460                    Type::Array => {
461                        if arg.allow_null {
462                            quote! { if #arg_name.len() == 0 { None } else { Some(*#arg_name) } }
463                        } else {
464                            quote! { #arg_name: *#arg_name }
465                        }
466                    },
467                    Type::Destructor => unreachable!(),
468                }
469            }
470        });
471
472        quote! {
473            #opcode => {
474                if let (#(#args_pat),*) = (#(#args_iter),*) {
475                    Ok((me, #msg_type::#msg_name { #(#arg_names),* }))
476                } else {
477                    Err(DispatchError::BadMessage { sender_id: msg.sender_id, interface: Self::interface().name, opcode: msg.opcode })
478                }
479            }
480        }
481    });
482
483    quote! {
484        let me = Self::from_id(conn, msg.sender_id.clone()).unwrap();
485        let mut arg_iter = msg.args.into_iter();
486        match msg.opcode {
487            #(#match_arms),*
488            _ => Err(DispatchError::BadMessage { sender_id: msg.sender_id, interface: Self::interface().name, opcode: msg.opcode }),
489        }
490    }
491}
492
493pub(crate) fn gen_write_body(interface: &Interface, side: Side) -> TokenStream {
494    let msgs = match side {
495        Side::Client => &interface.requests,
496        Side::Server => &interface.events,
497    };
498    let msg_type = Ident::new(
499        match side {
500            Side::Client => "Request",
501            Side::Server => "Event",
502        },
503        Span::call_site(),
504    );
505    let arms = msgs.iter().enumerate().map(|(opcode, msg)| {
506        let msg_name = Ident::new(&snake_to_camel(&msg.name), Span::call_site());
507        let opcode = opcode as u16;
508        let arg_names = msg.args.iter().flat_map(|arg| {
509            if arg.typ == Type::NewId && arg.interface.is_some() && side == Side::Client {
510                None
511            } else {
512                Some(format_ident!("{}{}", if is_keyword(&arg.name) { "_" } else { "" }, arg.name))
513            }
514        });
515        let mut child_spec = None;
516        let args = msg.args.iter().flat_map(|arg| {
517            let arg_name = format_ident!("{}{}", if is_keyword(&arg.name) { "_" } else { "" }, arg.name);
518
519            match arg.typ {
520                Type::Int => vec![if arg.enum_.is_some() { quote!{ Argument::Int(Into::<u32>::into(#arg_name) as i32) } } else { quote!{ Argument::Int(#arg_name) } }],
521                Type::Uint => vec![if arg.enum_.is_some() { quote!{ Argument::Uint(#arg_name.into()) } } else { quote!{ Argument::Uint(#arg_name) } }],
522                Type::Fd => vec![quote!{ Argument::Fd(#arg_name) }],
523                Type::Fixed => vec![quote! { Argument::Fixed((#arg_name * 256.) as i32) }],
524                Type::Object => if arg.allow_null {
525                    if side == Side::Server {
526                        vec![quote! { if let Some(obj) = #arg_name { Argument::Object(Resource::id(&obj)) } else { Argument::Object(ObjectId::null()) } }]
527                    } else {
528                        vec![quote! { if let Some(obj) = #arg_name { Argument::Object(Proxy::id(&obj)) } else { Argument::Object(ObjectId::null()) } }]
529                    }
530                } else if side == Side::Server {
531                    vec![quote!{ Argument::Object(Resource::id(&#arg_name)) }]
532                } else {
533                    vec![quote!{ Argument::Object(Proxy::id(&#arg_name)) }]
534                },
535                Type::Array => if arg.allow_null {
536                    vec![quote! { if let Some(array) = #arg_name { Argument::Array(Box::new(array)) } else { Argument::Array(Box::new(Vec::new()))}}]
537                } else {
538                    vec![quote! { Argument::Array(Box::new(#arg_name)) }]
539                },
540                Type::String => if arg.allow_null {
541                    vec![quote! { Argument::Str(#arg_name.map(|s| Box::new(std::ffi::CString::new(s).unwrap()))) }]
542                } else {
543                    vec![quote! { Argument::Str(Some(Box::new(std::ffi::CString::new(#arg_name).unwrap()))) }]
544                },
545                Type::NewId => if side == Side::Client {
546                    if let Some(ref created_interface) = arg.interface {
547                        let created_iface_mod = Ident::new(created_interface, Span::call_site());
548                        let created_iface_type = Ident::new(&snake_to_camel(created_interface), Span::call_site());
549                        assert!(child_spec.is_none());
550                        child_spec = Some(quote! { {
551                            let my_info = conn.object_info(self.id())?;
552                            Some((super::#created_iface_mod::#created_iface_type::interface(), my_info.version))
553                        } });
554                        vec![quote! { Argument::NewId(ObjectId::null()) }]
555                    } else {
556                        assert!(child_spec.is_none());
557                        child_spec = Some(quote! {
558                            Some((#arg_name.0, #arg_name.1))
559                        });
560                        vec![
561                            quote! {
562                                Argument::Str(Some(Box::new(std::ffi::CString::new(#arg_name.0.name).unwrap())))
563                            },
564                            quote! {
565                                Argument::Uint(#arg_name.1)
566                            },
567                            quote! {
568                                Argument::NewId(ObjectId::null())
569                            },
570                        ]
571                    }
572                } else {
573                    // server-side NewId is the same as Object
574                    if arg.allow_null {
575                        vec![quote! { if let Some(obj) = #arg_name { Argument::NewId(Resource::id(&obj)) } else { Argument::NewId(ObjectId::null()) } }]
576                    } else {
577                        vec![quote!{ Argument::NewId(Resource::id(&#arg_name)) }]
578                    }
579                },
580                Type::Destructor => panic!("Argument {}.{}.{} has type destructor ?!", interface.name, msg.name, arg.name),
581            }
582        });
583        let args = if msg.args.is_empty() {
584            quote! {
585                smallvec::SmallVec::new()
586            }
587        } else if msg.args.len() <= 4 {
588            // Note: Keep in sync with `wayland_backend::protocol::INLINE_ARGS`.
589            // Fits in SmallVec inline capacity
590            quote! { {
591                let mut vec = smallvec::SmallVec::new();
592                #(
593                    vec.push(#args);
594                )*
595                vec
596            } }
597        } else {
598            quote! {
599                smallvec::SmallVec::from_vec(vec![#(#args),*])
600            }
601        };
602        if side == Side::Client {
603            let child_spec = child_spec.unwrap_or_else(|| quote! { None });
604            quote! {
605                #msg_type::#msg_name { #(#arg_names),* } => {
606                    let child_spec = #child_spec;
607                    let args = #args;
608                    Ok((Message {
609                        sender_id: self.id.clone(),
610                        opcode: #opcode,
611                        args
612                    }, child_spec))
613                }
614            }
615        } else {
616            quote! {
617                #msg_type::#msg_name { #(#arg_names),* } => Ok(Message {
618                    sender_id: self.id.clone(),
619                    opcode: #opcode,
620                    args: #args,
621                })
622            }
623        }
624    });
625    quote! {
626        match msg {
627            #(#arms,)*
628            #msg_type::__phantom_lifetime { never, .. } => match never {}
629        }
630    }
631}