// palette/serde/alpha_deserializer.rs

1use core::marker::PhantomData;
2
3use serde::{
4    de::{DeserializeSeed, MapAccess, Visitor},
5    Deserialize, Deserializer,
6};
7
/// Wraps a color deserializer so that an alpha value stored next to the
/// color's own values (in the same flattened sequence, tuple, map or struct)
/// is captured on the side while the color itself is deserialized.
pub(crate) struct AlphaDeserializer<'a, D, A> {
    /// The underlying deserializer that provides the flattened data.
    pub inner: D,
    /// Destination for the alpha value; stays `None` if none is found.
    pub alpha: &'a mut Option<A>,
}
14
15impl<'de, 'a, D, A> Deserializer<'de> for AlphaDeserializer<'a, D, A>
16where
17    D: Deserializer<'de>,
18    A: Deserialize<'de>,
19{
20    type Error = D::Error;
21
22    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
23    where
24        V: serde::de::Visitor<'de>,
25    {
26        self.inner.deserialize_seq(AlphaSeqVisitor {
27            inner: visitor,
28            alpha: self.alpha,
29        })
30    }
31
32    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
33    where
34        V: serde::de::Visitor<'de>,
35    {
36        self.inner.deserialize_tuple(
37            len + 1,
38            AlphaMapVisitor {
39                inner: visitor,
40                alpha: self.alpha,
41                field_count: Some(len),
42            },
43        )
44    }
45
46    fn deserialize_tuple_struct<V>(
47        self,
48        name: &'static str,
49        len: usize,
50        visitor: V,
51    ) -> Result<V::Value, Self::Error>
52    where
53        V: serde::de::Visitor<'de>,
54    {
55        self.inner.deserialize_tuple_struct(
56            name,
57            len + 1,
58            AlphaMapVisitor {
59                inner: visitor,
60                alpha: self.alpha,
61                field_count: Some(len),
62            },
63        )
64    }
65
66    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
67    where
68        V: serde::de::Visitor<'de>,
69    {
70        self.inner.deserialize_map(AlphaMapVisitor {
71            inner: visitor,
72            alpha: self.alpha,
73            field_count: None,
74        })
75    }
76
77    fn deserialize_struct<V>(
78        self,
79        name: &'static str,
80        fields: &'static [&'static str],
81        visitor: V,
82    ) -> Result<V::Value, Self::Error>
83    where
84        V: serde::de::Visitor<'de>,
85    {
86        self.inner.deserialize_struct(
87            name,
88            fields, // We can't add to the expected fields so we just hope it works anyway.
89            AlphaMapVisitor {
90                inner: visitor,
91                alpha: self.alpha,
92                field_count: Some(fields.len()),
93            },
94        )
95    }
96
97    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
98    where
99        V: serde::de::Visitor<'de>,
100    {
101        self.inner.deserialize_ignored_any(AlphaSeqVisitor {
102            inner: visitor,
103            alpha: self.alpha,
104        })
105    }
106
107    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
108    where
109        V: serde::de::Visitor<'de>,
110    {
111        self.inner.deserialize_tuple(
112            1,
113            AlphaMapVisitor {
114                inner: visitor,
115                alpha: self.alpha,
116                field_count: None,
117            },
118        )
119    }
120
121    fn deserialize_unit_struct<V>(
122        self,
123        name: &'static str,
124        visitor: V,
125    ) -> Result<V::Value, Self::Error>
126    where
127        V: serde::de::Visitor<'de>,
128    {
129        self.inner.deserialize_newtype_struct(
130            name,
131            AlphaMapVisitor {
132                inner: visitor,
133                alpha: self.alpha,
134                field_count: Some(0),
135            },
136        )
137    }
138
139    fn deserialize_newtype_struct<V>(
140        self,
141        name: &'static str,
142        visitor: V,
143    ) -> Result<V::Value, Self::Error>
144    where
145        V: serde::de::Visitor<'de>,
146    {
147        self.deserialize_tuple_struct(name, 1, visitor)
148    }
149
150    // Unsupported methods:
151
152    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
153    where
154        V: serde::de::Visitor<'de>,
155    {
156        alpha_deserializer_error()
157    }
158
159    fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
160    where
161        V: serde::de::Visitor<'de>,
162    {
163        alpha_deserializer_error()
164    }
165
166    fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
167    where
168        V: serde::de::Visitor<'de>,
169    {
170        alpha_deserializer_error()
171    }
172
173    fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
174    where
175        V: serde::de::Visitor<'de>,
176    {
177        alpha_deserializer_error()
178    }
179
180    fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
181    where
182        V: serde::de::Visitor<'de>,
183    {
184        alpha_deserializer_error()
185    }
186
187    fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
188    where
189        V: serde::de::Visitor<'de>,
190    {
191        alpha_deserializer_error()
192    }
193
194    fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
195    where
196        V: serde::de::Visitor<'de>,
197    {
198        alpha_deserializer_error()
199    }
200
201    fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
202    where
203        V: serde::de::Visitor<'de>,
204    {
205        alpha_deserializer_error()
206    }
207
208    fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
209    where
210        V: serde::de::Visitor<'de>,
211    {
212        alpha_deserializer_error()
213    }
214
215    fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
216    where
217        V: serde::de::Visitor<'de>,
218    {
219        alpha_deserializer_error()
220    }
221
222    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
223    where
224        V: serde::de::Visitor<'de>,
225    {
226        alpha_deserializer_error()
227    }
228
229    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
230    where
231        V: serde::de::Visitor<'de>,
232    {
233        alpha_deserializer_error()
234    }
235
236    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
237    where
238        V: serde::de::Visitor<'de>,
239    {
240        alpha_deserializer_error()
241    }
242
243    fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
244    where
245        V: serde::de::Visitor<'de>,
246    {
247        alpha_deserializer_error()
248    }
249
250    fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
251    where
252        V: serde::de::Visitor<'de>,
253    {
254        alpha_deserializer_error()
255    }
256
257    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
258    where
259        V: serde::de::Visitor<'de>,
260    {
261        alpha_deserializer_error()
262    }
263
264    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: serde::de::Visitor<'de>,
267    {
268        alpha_deserializer_error()
269    }
270
271    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
272    where
273        V: serde::de::Visitor<'de>,
274    {
275        alpha_deserializer_error()
276    }
277
278    fn deserialize_enum<V>(
279        self,
280        _name: &'static str,
281        _variants: &'static [&'static str],
282        _visitor: V,
283    ) -> Result<V::Value, Self::Error>
284    where
285        V: serde::de::Visitor<'de>,
286    {
287        alpha_deserializer_error()
288    }
289
290    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
291    where
292        V: serde::de::Visitor<'de>,
293    {
294        alpha_deserializer_error()
295    }
296}
297
298fn alpha_deserializer_error() -> ! {
299    unimplemented!("AlphaDeserializer can only deserialize structs, maps and sequences")
300}
301
/// Visitor that lets the wrapped color visitor consume its elements from a
/// sequence, then reads one more element as the alpha value.
struct AlphaSeqVisitor<'a, D, A> {
    /// The color's own visitor.
    inner: D,
    /// Receives the trailing element, or `None` if the sequence ended first.
    alpha: &'a mut Option<A>,
}
307
308impl<'de, 'a, D, A> Visitor<'de> for AlphaSeqVisitor<'a, D, A>
309where
310    D: Visitor<'de>,
311    A: Deserialize<'de>,
312{
313    type Value = D::Value;
314
315    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
316        self.inner.expecting(formatter)?;
317        write!(formatter, " with an alpha value")
318    }
319
320    fn visit_seq<T>(self, mut seq: T) -> Result<Self::Value, T::Error>
321    where
322        T: serde::de::SeqAccess<'de>,
323    {
324        let color = self.inner.visit_seq(&mut seq)?;
325        *self.alpha = seq.next_element()?;
326
327        Ok(color)
328    }
329}
330
/// Visitor for colors stored as a map, a struct or a tuple, where the alpha
/// value sits next to the color's own entries: under an `"alpha"` key, or at
/// the position right after the color's fields.
struct AlphaMapVisitor<'a, D, A> {
    /// The color's own visitor.
    inner: D,
    /// Receives the alpha value once it is found.
    alpha: &'a mut Option<A>,
    /// Number of fields the color itself has, when known; used to recognize
    /// the alpha value by position. `None` is used for maps and unit colors.
    field_count: Option<usize>,
}
338
339impl<'de, 'a, D, A> Visitor<'de> for AlphaMapVisitor<'a, D, A>
340where
341    D: Visitor<'de>,
342    A: Deserialize<'de>,
343{
344    type Value = D::Value;
345
346    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
347        self.inner.expecting(formatter)?;
348        write!(formatter, " with an alpha value")
349    }
350
351    fn visit_seq<T>(self, mut seq: T) -> Result<Self::Value, T::Error>
352    where
353        T: serde::de::SeqAccess<'de>,
354    {
355        let color = if self.field_count.is_none() {
356            self.inner.visit_unit()?
357        } else {
358            self.inner.visit_seq(&mut seq)?
359        };
360        *self.alpha = seq.next_element()?;
361
362        Ok(color)
363    }
364
365    fn visit_map<T>(self, map: T) -> Result<Self::Value, T::Error>
366    where
367        T: serde::de::MapAccess<'de>,
368    {
369        self.inner.visit_map(MapWrapper {
370            inner: map,
371            alpha: self.alpha,
372            field_count: self.field_count,
373        })
374    }
375
376    fn visit_newtype_struct<T>(self, deserializer: T) -> Result<Self::Value, T::Error>
377    where
378        T: Deserializer<'de>,
379    {
380        *self.alpha = Some(A::deserialize(deserializer)?);
381        self.inner.visit_unit()
382    }
383}
384
/// `MapAccess` adapter that hides the alpha entry from the wrapped color
/// visitor, storing its value on the side while all other entries pass
/// through.
struct MapWrapper<'a, T, A> {
    /// The map being read.
    inner: T,
    /// Receives the value of the alpha entry.
    alpha: &'a mut Option<A>,
    /// Number of color fields, if known; used to recognize a numeric alpha key.
    field_count: Option<usize>,
}
392
393impl<'a, 'de, T, A> MapAccess<'de> for MapWrapper<'a, T, A>
394where
395    T: MapAccess<'de>,
396    A: Deserialize<'de>,
397{
398    type Error = T::Error;
399
400    fn next_key_seed<K>(&mut self, mut seed: K) -> Result<Option<K::Value>, Self::Error>
401    where
402        K: serde::de::DeserializeSeed<'de>,
403    {
404        // Look for and extract the alpha value if its key is found, then return
405        // the next key after that. The first key that isn't alpha is
406        // immediately returned to the wrapped type's visitor.
407        loop {
408            seed = match self.inner.next_key_seed(AlphaFieldDeserializerSeed {
409                inner: seed,
410                field_count: self.field_count,
411            }) {
412                Ok(Some(AlphaField::Alpha(seed))) => {
413                    // We found the alpha value, so deserialize it...
414                    if self.alpha.is_some() {
415                        return Err(serde::de::Error::duplicate_field("alpha"));
416                    }
417                    *self.alpha = Some(self.inner.next_value()?);
418
419                    // ...then give the seed back for the next key
420                    seed
421                }
422                Ok(Some(AlphaField::Other(other))) => return Ok(Some(other)),
423                Ok(None) => return Ok(None),
424                Err(error) => return Err(error),
425            };
426        }
427    }
428
429    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
430    where
431        V: serde::de::DeserializeSeed<'de>,
432    {
433        self.inner.next_value_seed(seed)
434    }
435}
436
/// Seed that reads a map key or struct field name and decides whether it
/// names the alpha value before involving the color's own key seed.
struct AlphaFieldDeserializerSeed<T> {
    /// The color's key seed, used for every key that isn't the alpha key.
    inner: T,
    /// Number of color fields, if known; a positional key equal to it is alpha.
    field_count: Option<usize>,
}
441
442impl<'de, T> DeserializeSeed<'de> for AlphaFieldDeserializerSeed<T>
443where
444    T: DeserializeSeed<'de>,
445{
446    type Value = AlphaField<T, T::Value>;
447
448    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
449    where
450        D: Deserializer<'de>,
451    {
452        deserializer.deserialize_identifier(AlphaFieldVisitor {
453            inner: self.inner,
454            field_count: self.field_count,
455        })
456    }
457}
458
/// Outcome of reading one key: either the alpha key, or any other key after
/// it has been deserialized by the color's own seed.
enum AlphaField<A, O> {
    /// The key named the alpha value. Carries the seed that was not used, so
    /// it can be reused for the next key.
    Alpha(A),
    /// Any other key, as produced by the color's seed.
    Other(O),
}
464
/// A non-alpha key that has been read from the input but not yet
/// deserialized by the color's own key seed.
///
/// The lifetime belongs to the borrowed key data. It is not tied to the
/// input's `'de` lifetime.
enum StructField<'a> {
    /// A numeric (positional) field index.
    Unsigned(u64),
    /// A field name given as a string.
    Str(&'a str),
    /// A field name given as raw bytes.
    Bytes(&'a [u8]),
}
471
/// Identifier visitor that recognizes the alpha key by name (`"alpha"`) or by
/// position, and forwards every other key to the color's own seed.
struct AlphaFieldVisitor<T> {
    /// The color's key seed.
    inner: T,
    /// Number of color fields, if known; needed to recognize positional keys.
    field_count: Option<usize>,
}
476
477impl<'de, T> Visitor<'de> for AlphaFieldVisitor<T>
478where
479    T: DeserializeSeed<'de>,
480{
481    type Value = AlphaField<T, T::Value>;
482
483    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
484        write!(formatter, "alpha field")
485    }
486
487    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
488    where
489        E: serde::de::Error,
490    {
491        // We need the field count here to get the last tuple field. No field
492        // count implies that we definitely expected a struct or a map.
493        let field_count = self.field_count.ok_or_else(|| {
494            serde::de::Error::invalid_type(
495                serde::de::Unexpected::Unsigned(v),
496                &"map key or struct field",
497            )
498        })?;
499
500        // Assume that it's the alpha value if it's after the expected number of
501        // fields. Otherwise, pass on to the wrapped type's deserializer.
502        if v == field_count as u64 {
503            Ok(AlphaField::Alpha(self.inner))
504        } else {
505            Ok(AlphaField::Other(self.inner.deserialize(
506                StructFieldDeserializer {
507                    struct_field: StructField::Unsigned(v),
508                    error: PhantomData,
509                },
510            )?))
511        }
512    }
513
514    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
515    where
516        E: serde::de::Error,
517    {
518        // Assume that it's the alpha value if it's named "alpha". Otherwise,
519        // pass on to the wrapped type's deserializer.
520        if v == "alpha" {
521            Ok(AlphaField::Alpha(self.inner))
522        } else {
523            Ok(AlphaField::Other(self.inner.deserialize(
524                StructFieldDeserializer {
525                    struct_field: StructField::Str(v),
526                    error: PhantomData,
527                },
528            )?))
529        }
530    }
531
532    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
533    where
534        E: serde::de::Error,
535    {
536        // Assume that it's the alpha value if it's named "alpha". Otherwise,
537        // pass on to the wrapped type's deserializer.
538        if v == b"alpha" {
539            Ok(AlphaField::Alpha(self.inner))
540        } else {
541            Ok(AlphaField::Other(self.inner.deserialize(
542                StructFieldDeserializer {
543                    struct_field: StructField::Bytes(v),
544                    error: PhantomData,
545                },
546            )?))
547        }
548    }
549}
550
551/// Deserializes a non-alpha struct field name.
552struct StructFieldDeserializer<'a, E> {
553    struct_field: StructField<'a>,
554    error: PhantomData<fn() -> E>,
555}
556
557impl<'a, 'de, E> Deserializer<'de> for StructFieldDeserializer<'a, E>
558where
559    E: serde::de::Error,
560{
561    type Error = E;
562
563    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
564    where
565        V: Visitor<'de>,
566    {
567        match self.struct_field {
568            StructField::Unsigned(v) => visitor.visit_u64(v),
569            StructField::Str(v) => visitor.visit_str(v),
570            StructField::Bytes(v) => visitor.visit_bytes(v),
571        }
572    }
573
574    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
575    where
576        V: Visitor<'de>,
577    {
578        self.deserialize_identifier(visitor)
579    }
580
581    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
582    where
583        V: Visitor<'de>,
584    {
585        self.deserialize_identifier(visitor)
586    }
587
588    // Unsupported methods::
589
590    fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
591    where
592        V: Visitor<'de>,
593    {
594        struct_field_deserializer_error()
595    }
596
597    fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
598    where
599        V: Visitor<'de>,
600    {
601        struct_field_deserializer_error()
602    }
603
604    fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
605    where
606        V: Visitor<'de>,
607    {
608        struct_field_deserializer_error()
609    }
610
611    fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
612    where
613        V: Visitor<'de>,
614    {
615        struct_field_deserializer_error()
616    }
617
618    fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
619    where
620        V: Visitor<'de>,
621    {
622        struct_field_deserializer_error()
623    }
624
625    fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
626    where
627        V: Visitor<'de>,
628    {
629        struct_field_deserializer_error()
630    }
631
632    fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
633    where
634        V: Visitor<'de>,
635    {
636        struct_field_deserializer_error()
637    }
638
639    fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
640    where
641        V: Visitor<'de>,
642    {
643        struct_field_deserializer_error()
644    }
645
646    fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
647    where
648        V: Visitor<'de>,
649    {
650        struct_field_deserializer_error()
651    }
652
653    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
654    where
655        V: Visitor<'de>,
656    {
657        struct_field_deserializer_error()
658    }
659
660    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
661    where
662        V: Visitor<'de>,
663    {
664        struct_field_deserializer_error()
665    }
666
667    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
668    where
669        V: Visitor<'de>,
670    {
671        struct_field_deserializer_error()
672    }
673
674    fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
675    where
676        V: Visitor<'de>,
677    {
678        struct_field_deserializer_error()
679    }
680
681    fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
682    where
683        V: Visitor<'de>,
684    {
685        struct_field_deserializer_error()
686    }
687
688    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
689    where
690        V: Visitor<'de>,
691    {
692        struct_field_deserializer_error()
693    }
694
695    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
696    where
697        V: Visitor<'de>,
698    {
699        struct_field_deserializer_error()
700    }
701
702    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
703    where
704        V: Visitor<'de>,
705    {
706        struct_field_deserializer_error()
707    }
708
709    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
710    where
711        V: Visitor<'de>,
712    {
713        struct_field_deserializer_error()
714    }
715
716    fn deserialize_unit_struct<V>(
717        self,
718        _name: &'static str,
719        _visitor: V,
720    ) -> Result<V::Value, Self::Error>
721    where
722        V: Visitor<'de>,
723    {
724        struct_field_deserializer_error()
725    }
726
727    fn deserialize_newtype_struct<V>(
728        self,
729        _name: &'static str,
730        _visitor: V,
731    ) -> Result<V::Value, Self::Error>
732    where
733        V: Visitor<'de>,
734    {
735        struct_field_deserializer_error()
736    }
737
738    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
739    where
740        V: Visitor<'de>,
741    {
742        struct_field_deserializer_error()
743    }
744
745    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
746    where
747        V: Visitor<'de>,
748    {
749        struct_field_deserializer_error()
750    }
751
752    fn deserialize_tuple_struct<V>(
753        self,
754        _name: &'static str,
755        _len: usize,
756        _visitor: V,
757    ) -> Result<V::Value, Self::Error>
758    where
759        V: Visitor<'de>,
760    {
761        struct_field_deserializer_error()
762    }
763
764    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
765    where
766        V: Visitor<'de>,
767    {
768        struct_field_deserializer_error()
769    }
770
771    fn deserialize_struct<V>(
772        self,
773        _name: &'static str,
774        _fields: &'static [&'static str],
775        _visitor: V,
776    ) -> Result<V::Value, Self::Error>
777    where
778        V: Visitor<'de>,
779    {
780        struct_field_deserializer_error()
781    }
782
783    fn deserialize_enum<V>(
784        self,
785        _name: &'static str,
786        _variants: &'static [&'static str],
787        _visitor: V,
788    ) -> Result<V::Value, Self::Error>
789    where
790        V: Visitor<'de>,
791    {
792        struct_field_deserializer_error()
793    }
794}
795
796fn struct_field_deserializer_error() -> ! {
797    unimplemented!("StructFieldDeserializer can only deserialize identifiers")
798}