quick_xml/
utils.rs

1use std::borrow::{Borrow, Cow};
2use std::fmt::{self, Debug, Formatter};
3use std::io;
4use std::ops::Deref;
5
6#[cfg(feature = "async-tokio")]
7use std::{
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12#[cfg(feature = "serialize")]
13use serde::de::{Deserialize, Deserializer, Error, Visitor};
14#[cfg(feature = "serialize")]
15use serde::ser::{Serialize, Serializer};
16
17#[allow(clippy::ptr_arg)]
18pub fn write_cow_string(f: &mut Formatter, cow_string: &Cow<[u8]>) -> fmt::Result {
19    match cow_string {
20        Cow::Owned(s) => {
21            write!(f, "Owned(")?;
22            write_byte_string(f, s)?;
23        }
24        Cow::Borrowed(s) => {
25            write!(f, "Borrowed(")?;
26            write_byte_string(f, s)?;
27        }
28    }
29    write!(f, ")")
30}
31
/// Writes `byte_string` enclosed in double quotes in a human-readable form:
///
/// - printable ASCII (space through `~`) is written as is,
/// - a double quote is escaped as `\"`,
/// - every other byte is written as upper-case hex with a `0x` prefix and
///   without zero padding (for example, a line feed becomes `0xA`).
pub fn write_byte_string(f: &mut Formatter, byte_string: &[u8]) -> fmt::Result {
    f.write_str("\"")?;
    byte_string.iter().try_for_each(|&byte| match byte {
        // Must precede the printable range below, which also contains `"`.
        b'"' => f.write_str("\\\""),
        b' '..=b'~' => write!(f, "{}", byte as char),
        // The `0x` prefix added by `#` counts towards the width of 2, so the
        // width never actually pads the output.
        _ => write!(f, "{:#02X}", byte),
    })?;
    f.write_str("\"")
}
44
45////////////////////////////////////////////////////////////////////////////////////////////////////
46
/// A version of [`Cow`] that can borrow from two different buffers, one of
/// which is the deserializer input.
///
/// # Lifetimes
///
/// - `'i`: lifetime of the data that the deserializer borrows from the parsed input
/// - `'s`: lifetime of the data that is owned by a deserializer
pub enum CowRef<'i, 's, B: ToOwned + ?Sized> {
    /// Data borrowed from the parsed input
    Input(&'i B),
    /// Data borrowed from a buffer owned by another deserializer
    Slice(&'s B),
    /// Data taken from an external deserializer; this value now owns it
    Owned(<B as ToOwned>::Owned),
}
65impl<'i, 's, B> Deref for CowRef<'i, 's, B>
66where
67    B: ToOwned + ?Sized,
68    B::Owned: Borrow<B>,
69{
70    type Target = B;
71
72    fn deref(&self) -> &B {
73        match *self {
74            Self::Input(borrowed) => borrowed,
75            Self::Slice(borrowed) => borrowed,
76            Self::Owned(ref owned) => owned.borrow(),
77        }
78    }
79}
80
81impl<'i, 's, B> Debug for CowRef<'i, 's, B>
82where
83    B: ToOwned + ?Sized + Debug,
84    B::Owned: Debug,
85{
86    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
87        match *self {
88            Self::Input(borrowed) => Debug::fmt(borrowed, f),
89            Self::Slice(borrowed) => Debug::fmt(borrowed, f),
90            Self::Owned(ref owned) => Debug::fmt(owned, f),
91        }
92    }
93}
94
95impl<'i, 's> CowRef<'i, 's, str> {
96    /// Supply to the visitor a borrowed string, a string slice, or an owned
97    /// string depending on the kind of input. Unlike [`Self::deserialize_all`],
98    /// only part of [`Self::Owned`] string will be passed to the visitor.
99    ///
100    /// Calls
101    /// - `visitor.visit_borrowed_str` if data borrowed from the input
102    /// - `visitor.visit_str` if data borrowed from another source
103    /// - `visitor.visit_string` if data owned by this type
104    #[cfg(feature = "serialize")]
105    pub fn deserialize_str<V, E>(self, visitor: V) -> Result<V::Value, E>
106    where
107        V: Visitor<'i>,
108        E: Error,
109    {
110        match self {
111            Self::Input(s) => visitor.visit_borrowed_str(s),
112            Self::Slice(s) => visitor.visit_str(s),
113            Self::Owned(s) => visitor.visit_string(s),
114        }
115    }
116
117    /// Calls [`Visitor::visit_bool`] with `true` or `false` if text contains
118    /// [valid] boolean representation, otherwise calls [`Self::deserialize_str`].
119    ///
120    /// The valid boolean representations are only `"true"`, `"false"`, `"1"`, and `"0"`.
121    ///
122    /// [valid]: https://www.w3.org/TR/xmlschema11-2/#boolean
123    #[cfg(feature = "serialize")]
124    pub fn deserialize_bool<V, E>(self, visitor: V) -> Result<V::Value, E>
125    where
126        V: Visitor<'i>,
127        E: Error,
128    {
129        match self.as_ref() {
130            "1" | "true" => visitor.visit_bool(true),
131            "0" | "false" => visitor.visit_bool(false),
132            _ => self.deserialize_str(visitor),
133        }
134    }
135}
136
137////////////////////////////////////////////////////////////////////////////////////////////////////
138
/// Wrapper around `Vec<u8>` that has a human-readable debug representation:
/// printable ASCII symbols are output as is (a double quote is escaped as `\"`),
/// and all other bytes are output in hex notation.
///
/// Also, when the [`serialize`] feature is on, this type is deserialized using
/// [`deserialize_byte_buf`](serde::Deserializer::deserialize_byte_buf) instead
/// of the vector's generic [`deserialize_seq`](serde::Deserializer::deserialize_seq),
/// and serialized using [`serialize_bytes`](serde::Serializer::serialize_bytes).
///
/// [`serialize`]: ../index.html#serialize
#[derive(PartialEq, Eq)]
pub struct ByteBuf(pub Vec<u8>);
149
150impl Debug for ByteBuf {
151    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
152        write_byte_string(f, &self.0)
153    }
154}
155
#[cfg(feature = "serialize")]
impl<'de> Deserialize<'de> for ByteBuf {
    /// Requests byte data through `deserialize_byte_buf`. The visitor accepts
    /// transient bytes (copied into a new vector) and owned byte buffers
    /// (taken over as is).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ByteBufVisitor;

        impl<'a> Visitor<'a> for ByteBufVisitor {
            type Value = ByteBuf;

            fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
                formatter.write_str("byte data")
            }

            fn visit_bytes<E: Error>(self, bytes: &[u8]) -> Result<Self::Value, E> {
                // The slice does not outlive this call, so keep a copy of it.
                self.visit_byte_buf(bytes.to_owned())
            }

            fn visit_byte_buf<E: Error>(self, bytes: Vec<u8>) -> Result<Self::Value, E> {
                Ok(ByteBuf(bytes))
            }
        }

        deserializer.deserialize_byte_buf(ByteBufVisitor)
    }
}
183
#[cfg(feature = "serialize")]
impl Serialize for ByteBuf {
    /// Serializes the wrapped vector with `serialize_bytes` rather than as a
    /// sequence of individual numbers.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let ByteBuf(bytes) = self;
        serializer.serialize_bytes(bytes)
    }
}
193
194////////////////////////////////////////////////////////////////////////////////////////////////////
195
/// Wrapper around `&[u8]` that has a human-readable debug representation:
/// printable ASCII symbols are output as is (a double quote is escaped as `\"`),
/// and all other bytes are output in hex notation.
///
/// Also, when the [`serialize`] feature is on, this type is deserialized using
/// [`deserialize_bytes`](serde::Deserializer::deserialize_bytes) and accepts
/// only bytes borrowed directly from the deserializer input. It is serialized
/// using [`serialize_bytes`](serde::Serializer::serialize_bytes).
///
/// [`serialize`]: ../index.html#serialize
#[derive(PartialEq, Eq)]
pub struct Bytes<'de>(pub &'de [u8]);
206
207impl<'de> Debug for Bytes<'de> {
208    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
209        write_byte_string(f, self.0)
210    }
211}
212
#[cfg(feature = "serialize")]
impl<'de> Deserialize<'de> for Bytes<'de> {
    /// Requests bytes through `deserialize_bytes`. The visitor only implements
    /// `visit_borrowed_bytes`, so only data borrowed for the whole `'de`
    /// lifetime is accepted; other inputs are rejected by the visitor's default methods.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct BorrowedBytesVisitor;

        impl<'a> Visitor<'a> for BorrowedBytesVisitor {
            type Value = Bytes<'a>;

            fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
                formatter.write_str("borrowed bytes")
            }

            fn visit_borrowed_bytes<E: Error>(self, bytes: &'a [u8]) -> Result<Self::Value, E> {
                Ok(Bytes(bytes))
            }
        }

        deserializer.deserialize_bytes(BorrowedBytesVisitor)
    }
}
236
#[cfg(feature = "serialize")]
impl<'de> Serialize for Bytes<'de> {
    /// Serializes the wrapped slice with `serialize_bytes` rather than as a
    /// sequence of individual numbers.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Bytes(bytes) = self;
        serializer.serialize_bytes(bytes)
    }
}
246
247////////////////////////////////////////////////////////////////////////////////////////////////////
248
/// A simple producer of an infinite stream of bytes, useful in tests.
///
/// Will repeat the `chunk` field indefinitely. Reading through either the
/// [`io::Read`] or the [`io::BufRead`] implementation walks through `chunk`,
/// wraps back to its start when the end is reached, and adds every consumed
/// byte to `overall_read`. An empty `chunk` produces an empty stream: reads
/// return `0` immediately.
pub struct Fountain<'a> {
    /// That piece of data repeated infinitely...
    pub chunk: &'a [u8],
    /// Offset in `chunk` of the first byte not consumed yet; reset to `0` by
    /// `consume` when it reaches `chunk.len()`
    pub consumed: usize,
    /// The overall count of read bytes
    pub overall_read: u64,
}

impl<'a> io::Read for Fountain<'a> {
    /// Copies as many not-yet-consumed bytes of the current repetition of
    /// `chunk` as fit into `buf`, then consumes them so that the next call
    /// continues from there.
    ///
    /// A single call never crosses the end of `chunk`, which may produce a
    /// short read; `io::Read` permits that.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let available = io::BufRead::fill_buf(&mut *self)?;
        let len = buf.len().min(available.len());
        // Fill only the prefix that receives data: `copy_from_slice` panics when
        // the slices differ in length, which would happen for any `buf` longer
        // than the remaining part of `chunk`.
        buf[..len].copy_from_slice(&available[..len]);
        // Advance (and wrap) the position and update `overall_read`.
        io::BufRead::consume(self, len);
        Ok(len)
    }
}

impl<'a> io::BufRead for Fountain<'a> {
    /// Returns the not-yet-consumed remainder of the current repetition of `chunk`.
    #[inline]
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Ok(&self.chunk[self.consumed..])
    }

    /// Marks `amt` bytes as consumed, restarting from the beginning of `chunk`
    /// once its end is reached, and adds `amt` to `overall_read`.
    ///
    /// As the `BufRead` contract requires, `amt` must not exceed the length of
    /// the slice last returned by `fill_buf`.
    fn consume(&mut self, amt: usize) {
        self.consumed += amt;
        if self.consumed == self.chunk.len() {
            self.consumed = 0;
        }
        self.overall_read += amt as u64;
    }
}
286
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncRead for Fountain<'a> {
    /// Fills `buf` with as many not-yet-consumed bytes of the current repetition
    /// of `chunk` as fit, then consumes them so that the next poll continues from
    /// there. Always returns `Poll::Ready`.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `Fountain` holds only a slice reference and integers, so it is `Unpin`
        // and can be used through a plain mutable reference.
        let this = self.get_mut();
        let available = &this.chunk[this.consumed..];
        let len = buf.remaining().min(available.len());
        buf.put_slice(&available[..len]);
        // Without advancing, a partial read would hand out the same bytes on
        // every poll and `overall_read` would never change; `consume` moves
        // (and wraps) the position and counts the bytes.
        io::BufRead::consume(this, len);
        Poll::Ready(Ok(()))
    }
}
302
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncBufRead for Fountain<'a> {
    /// Delegates to the synchronous [`io::BufRead::fill_buf`]; never pending.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this: &mut Self = Pin::into_inner(self);
        Poll::Ready(io::BufRead::fill_buf(this))
    }

    /// Delegates to the synchronous [`io::BufRead::consume`].
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this: &mut Self = Pin::into_inner(self);
        io::BufRead::consume(this, amt)
    }
}
315
316////////////////////////////////////////////////////////////////////////////////////////////////////
317
/// Returns `true` if `b` is an XML whitespace byte: a space, carriage return,
/// line feed or horizontal tab.
#[inline]
pub const fn is_whitespace(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
323
324/// Calculates name from an element-like content. Name is the first word in `content`,
325/// where word boundaries is XML whitespace characters.
326///
327/// 'Whitespace' refers to the definition used by [`is_whitespace`].
328#[inline]
329pub const fn name_len(mut bytes: &[u8]) -> usize {
330    // Note: A pattern matching based approach (instead of indexing) allows
331    // making the function const.
332    let mut len = 0;
333    while let [first, rest @ ..] = bytes {
334        if is_whitespace(*first) {
335            break;
336        }
337        len += 1;
338        bytes = rest;
339    }
340    len
341}
342
343/// Returns a byte slice with leading XML whitespace bytes removed.
344///
345/// 'Whitespace' refers to the definition used by [`is_whitespace`].
346#[inline]
347pub const fn trim_xml_start(mut bytes: &[u8]) -> &[u8] {
348    // Note: A pattern matching based approach (instead of indexing) allows
349    // making the function const.
350    while let [first, rest @ ..] = bytes {
351        if is_whitespace(*first) {
352            bytes = rest;
353        } else {
354            break;
355        }
356    }
357    bytes
358}
359
360/// Returns a byte slice with trailing XML whitespace bytes removed.
361///
362/// 'Whitespace' refers to the definition used by [`is_whitespace`].
363#[inline]
364pub const fn trim_xml_end(mut bytes: &[u8]) -> &[u8] {
365    // Note: A pattern matching based approach (instead of indexing) allows
366    // making the function const.
367    while let [rest @ .., last] = bytes {
368        if is_whitespace(*last) {
369            bytes = rest;
370        } else {
371            break;
372        }
373    }
374    bytes
375}
376
377////////////////////////////////////////////////////////////////////////////////////////////////////
378
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn write_byte_string0() {
        // A line feed (not printable, so rendered in hex) followed by eight spaces
        let mut raw = vec![b'\n'];
        raw.extend_from_slice(&[b' '; 8]);
        let rendered = format!("{:?}", ByteBuf(raw));
        assert_eq!(rendered, format!("\"0xA{}\"", " ".repeat(8)));
    }

    #[test]
    fn write_byte_string1() {
        let bytes = ByteBuf(b"http://www.w3.org/2002/07/owl#".to_vec());
        assert_eq!(
            format!("{:?}", bytes),
            r##""http://www.w3.org/2002/07/owl#""##
        );
    }

    #[test]
    fn write_byte_string3() {
        // Double quotes inside the data are escaped with a backslash
        let bytes = ByteBuf(b"Class IRI=\"#B\"".to_vec());
        assert_eq!(format!("{:?}", bytes), r##""Class IRI=\"#B\"""##);
    }

    #[test]
    fn name_len() {
        /// Asserts that the first word of `input` is `expected` bytes long.
        fn check(input: &[u8], expected: usize) {
            assert_eq!(super::name_len(input), expected);
        }

        check(b"", 0);
        check(b" abc", 0);
        check(b" \t\r\n", 0);

        check(b"abc", 3);
        check(b"abc ", 3);

        check(b"a bc", 1);
        check(b"ab\tc", 2);
        check(b"ab\rc", 2);
        check(b"ab\nc", 2);
    }

    #[test]
    fn trim_xml_start() {
        /// Asserts that trimming leading whitespace from `input` gives `expected`.
        fn check(input: &[u8], expected: &[u8]) {
            assert_eq!(Bytes(super::trim_xml_start(input)), Bytes(expected));
        }

        check(b"", b"");
        check(b"abc", b"abc");
        check(b"\r\n\t ab \t\r\nc \t\r\n", b"ab \t\r\nc \t\r\n");
    }

    #[test]
    fn trim_xml_end() {
        /// Asserts that trimming trailing whitespace from `input` gives `expected`.
        fn check(input: &[u8], expected: &[u8]) {
            assert_eq!(Bytes(super::trim_xml_end(input)), Bytes(expected));
        }

        check(b"", b"");
        check(b"abc", b"abc");
        check(b"\r\n\t ab \t\r\nc \t\r\n", b"\r\n\t ab \t\r\nc");
    }
}