zune_jpeg/
bitstream.rs

1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9#![allow(
10    clippy::if_not_else,
11    clippy::similar_names,
12    clippy::inline_always,
13    clippy::doc_markdown,
14    clippy::cast_sign_loss,
15    clippy::cast_possible_truncation
16)]
17
18//! This file exposes a single struct that can decode a huffman encoded
19//! Bitstream in a JPEG file
20//!
21//! This code is optimized for speed.
22//! It's meant to be super duper super fast, because everyone else depends on this being fast.
//! It's (annoyingly) serial, hence we can't use parallel bitstreams (it's variable length coding).
24//!
25//! Furthermore, on the case of refills, we have to do bytewise processing because the standard decided
26//! that we want to support markers in the middle of streams(seriously few people use RST markers).
27//!
28//! So we pull in all optimization steps:
29//! - use `inline[always]`? ✅ ,
30//! - pre-execute most common cases ✅,
31//! - add random comments ✅
32//! -  fast paths ✅.
33//!
34//! Speed-wise: It is probably the fastest JPEG BitStream decoder to ever sail the seven seas because of
35//! a couple of optimization tricks.
36//! 1. Fast refills from libjpeg-turbo
37//! 2. As few as possible branches in decoder fast paths.
//! 3. Accelerated AC table decoding borrowed from stb_image.h written by Fabian Giesen (@rygorous),
39//! improved by me to handle more cases.
40//! 4. Safe and extensible routines(e.g. cool ways to eliminate bounds check)
41//! 5. No unsafe here
42//!
43//! Readability comes as a second priority(I tried with variable names this time, and we are wayy better than libjpeg).
44//!
//! Anyway, if you are reading this it means you're cool and I hope you get whatever part of the code you are looking for
46//! (or learn something cool)
47//!
48//! Knock yourself out.
49use alloc::format;
50use alloc::string::ToString;
51use core::cmp::min;
52
53use zune_core::bytestream::{ZByteReader, ZReaderTrait};
54
55use crate::errors::DecodeErrors;
56use crate::huffman::{HuffmanTable, HUFF_LOOKAHEAD};
57use crate::marker::Marker;
58use crate::mcu::DCT_BLOCK;
59use crate::misc::UN_ZIGZAG;
60
// Decode one Huffman symbol from `$stream` using `$table`.
//
// On entry `$symbol` must hold the `$table.lookup` entry selected by the next
// `HUFF_LOOKAHEAD` bits of the stream: the bits above `HUFF_LOOKAHEAD` carry the
// code length, the low `HUFF_LOOKAHEAD` bits carry the decoded value.
//
// On exit `$symbol` holds the decoded value and the bits that made up the code
// have been dropped from `$stream`.
//
// If no code of up to 16 bits matches, the macro executes
// `return Err(DecodeErrors::Format(..))` in the *enclosing* function, so it may
// only be used inside functions returning `Result<_, DecodeErrors>`.
macro_rules! decode_huff {
    ($stream:tt,$symbol:tt,$table:tt) => {
        // upper part of the lookup entry: length of the code
        let mut code_length = $symbol >> HUFF_LOOKAHEAD;

        // lower part of the lookup entry: the decoded value
        ($symbol) &= (1 << HUFF_LOOKAHEAD) - 1;

        if code_length > i32::from(HUFF_LOOKAHEAD)
        {
            // if the symbol cannot be resolved in the first HUFF_LOOKAHEAD bits,
            // we know it lies somewhere between HUFF_LOOKAHEAD and 16 bits since jpeg imposes 16 bit
            // limit, we can therefore look 16 bits ahead and try to resolve the symbol
            // starting from 1+HUFF_LOOKAHEAD bits.
            $symbol = ($stream).peek_bits::<16>() as i32;
            // (Credits to Sean T. Barrett stb library for this optimization)
            // maxcode is pre-shifted to be 16 bits long so that it has (16-code_length)
            // zeroes at the end hence we do not need to shift in the inner loop.
            while code_length < 17{
                if $symbol < $table.maxcode[code_length as usize]  {
                    break;
                }
                code_length += 1;
            }

            if code_length == 17{
                // symbol could not be decoded.
                //
                // We could fake zeroes here, but we do not: Huffman codes are
                // sensitive, probably everything after this will be corrupt, so
                // we return an error instead of continuing.
                return Err(DecodeErrors::Format(format!("Bad Huffman Code 0x{:X}, corrupt JPEG",$symbol)))
            }

            // keep only the top `code_length` bits of the 16 bits we peeked
            $symbol >>= (16-code_length);
            // translate the code into its value via the per-length offset
            ($symbol) = i32::from(
                ($table).values
                    [(($symbol + ($table).offset[code_length as usize]) & 0xFF) as usize],
            );
        }
        // drop bits read
        ($stream).drop_bits(code_length as u8);
    };
}
103
/// A `BitStream` struct, a bit by bit reader with super powers
///
/// It owns the bit buffers used while entropy-decoding a scan, remembers the
/// marker (if any) that `refill` ran into, and carries the parameters of the
/// current progressive scan.
pub(crate) struct BitStream {
    /// A MSB type buffer that is used for some certain operations
    ///
    /// `refill` appends freshly read bytes at the low end of this buffer.
    pub buffer:           u64,
    /// A TOP  aligned MSB type buffer that is used to accelerate some operations like
    /// peek_bits and get_bits.
    ///
    /// By top aligned, I mean the top bit (63) represents the top bit in the buffer.
    aligned_buffer:       u64,
    /// Number of unread bits currently held in the two buffers
    pub(crate) bits_left: u8,
    /// Did we find a marker(RST/EOF) during decoding?
    ///
    /// While this is `Some`, `refill` does not read any more bytes.
    pub marker:           Option<Marker>,

    /// Progressive decoding: the `ah` value given to `new_progressive` /
    /// `update_progressive_params`
    pub successive_high: u8,
    /// Progressive decoding: the `al` value; used as the left-shift applied to
    /// decoded coefficients
    pub successive_low:  u8,
    /// Index at which the progressive AC decoders start within a block
    spec_start:          u8,
    /// Last index (inclusive) the progressive AC decoders visit within a block
    spec_end:            u8,
    /// End-of-band run counter maintained by the progressive AC decoders
    pub eob_run:         i32,
    /// Incremented once for every byte `refill` fetches bytewise while the
    /// reader reports end of file
    pub overread_by:     usize
}
127
128impl BitStream {
129    /// Create a new BitStream
130    pub(crate) const fn new() -> BitStream {
131        BitStream {
132            buffer:          0,
133            aligned_buffer:  0,
134            bits_left:       0,
135            marker:          None,
136            successive_high: 0,
137            successive_low:  0,
138            spec_start:      0,
139            spec_end:        0,
140            eob_run:         0,
141            overread_by:     0
142        }
143    }
144
145    /// Create a new Bitstream for progressive decoding
146    #[allow(clippy::redundant_field_names)]
147    pub(crate) fn new_progressive(ah: u8, al: u8, spec_start: u8, spec_end: u8) -> BitStream {
148        BitStream {
149            buffer:          0,
150            aligned_buffer:  0,
151            bits_left:       0,
152            marker:          None,
153            successive_high: ah,
154            successive_low:  al,
155            spec_start:      spec_start,
156            spec_end:        spec_end,
157            eob_run:         0,
158            overread_by:     0
159        }
160    }
161
    /// Refill the bit buffer by (a maximum of) 32 bits
    ///
    /// # Arguments
    ///  - `reader`: A mutable reference to the underlying `ZByteReader`
    ///    from which the entropy coded bytes are pulled
    ///
    /// This function will only refill if `self.bits_left` is less than 32 and
    /// no marker has been recorded in `self.marker`.
    ///
    /// # Returns
    /// - `Ok(true)` if the buffer was refilled, or if no refill was needed.
    /// - `Ok(false)` if a marker was encountered; it is stored in `self.marker`
    ///   and the `0xFF` byte that introduced it is removed from the buffer again.
    /// - `Err(DecodeErrors::Format)` if the byte following `0xFF` is not a marker
    ///   known to `Marker::from_u8`.
    #[inline(always)] // too many call sites? ( perf improvement by 4%)
    fn refill<T>(&mut self, reader: &mut ZByteReader<T>) -> Result<bool, DecodeErrors>
    where
        T: ZReaderTrait
    {
        /// Macro version of a single byte refill.
        /// Arguments
        /// buffer-> our io buffer, because rust macros cannot get values from
        /// the surrounding environment bits_left-> number of bits left
        /// to full refill
        ///
        /// On a marker this macro returns `Ok(false)` from the enclosing
        /// function (or an error if the marker is unknown).
        macro_rules! refill {
            ($buffer:expr,$byte:expr,$bits_left:expr) => {
                // read a byte from the stream
                $byte = u64::from(reader.get_u8());
                // count this byte as an over-read if the reader reports EOF
                self.overread_by += usize::from(reader.eof());
                // append to the buffer
                // JPEG is a MSB type buffer so that means we append this
                // to the lower end (0..8) of the buffer and push the rest bits above..
                $buffer = ($buffer << 8) | $byte;
                // Increment bits left
                $bits_left += 8;
                // Check for special case  of OxFF, to see if it's a stream or a marker
                if $byte == 0xff {
                    // read next byte
                    let mut next_byte = u64::from(reader.get_u8());
                    // Byte stuffing: a 0x00 after 0xFF means the 0xFF was data,
                    // so the 0x00 is simply skipped
                    if next_byte != 0x00 {
                        // skip any run of 0xFF bytes
                        while next_byte == 0xFF {
                            next_byte = u64::from(reader.get_u8());
                        }

                        if next_byte != 0x00 {
                            // Undo the byte append and return
                            $buffer >>= 8;
                            $bits_left -= 8;

                            // with zero bits left the shift amount would be 64,
                            // so the aligned buffer is only rebuilt otherwise
                            if $bits_left != 0 {
                                self.aligned_buffer = $buffer << (64 - $bits_left);
                            }

                            self.marker =
                                Some(Marker::from_u8(next_byte as u8).ok_or_else(|| {
                                    DecodeErrors::Format(format!(
                                        "Unknown marker 0xFF{:X}",
                                        next_byte
                                    ))
                                })?);
                            return Ok(false);
                        }
                    }
                }
            };
        }

        // 32 bits is enough for a decode(16 bits) and receive_extend(max 16 bits)
        // If we have less than 32 bits we refill
        if self.bits_left < 32 && self.marker.is_none() {
            // So before we do anything, check if we have a 0xFF byte

            if reader.has(4) {
                // we have 4 bytes to spare, read the 4 bytes into a temporary buffer
                // create buffer
                let msb_buf = reader.get_u32_be();
                // check if we have 0xff
                if !has_byte(msb_buf, 255) {
                    // fast path: none of the four bytes needs special handling,
                    // append all 32 bits at once
                    self.bits_left += 32;
                    self.buffer <<= 32;
                    self.buffer |= u64::from(msb_buf);
                    self.aligned_buffer = self.buffer << (64 - self.bits_left);
                    return Ok(true);
                }
                // a 0xFF byte is present: rewind the read and fall back to the
                // bytewise path below
                reader.rewind(4);
            }
            // This serves two reasons,
            // 1: Make clippy shut up
            // 2: Favour register reuse
            let mut byte;

            // 4 refills, if all succeed the stream should contain enough bits to decode a
            // value
            refill!(self.buffer, byte, self.bits_left);
            refill!(self.buffer, byte, self.bits_left);
            refill!(self.buffer, byte, self.bits_left);
            refill!(self.buffer, byte, self.bits_left);
            // Construct an MSB buffer whose top bits are the bitstream we are currently holding.
            self.aligned_buffer = self.buffer << (64 - self.bits_left);
        }

        return Ok(true);
    }
261    /// Decode the DC coefficient in a MCU block.
262    ///
263    /// The decoded coefficient is written to `dc_prediction`
264    ///
265    #[allow(
266        clippy::cast_possible_truncation,
267        clippy::cast_sign_loss,
268        clippy::unwrap_used
269    )]
270    #[inline(always)]
271    fn decode_dc<T>(
272        &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, dc_prediction: &mut i32
273    ) -> Result<bool, DecodeErrors>
274    where
275        T: ZReaderTrait
276    {
277        let (mut symbol, r);
278
279        if self.bits_left < 32 {
280            self.refill(reader)?;
281        };
282        // look a head HUFF_LOOKAHEAD bits into the bitstream
283        symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
284        symbol = dc_table.lookup[symbol as usize];
285
286        decode_huff!(self, symbol, dc_table);
287
288        if symbol != 0 {
289            r = self.get_bits(symbol as u8);
290            symbol = huff_extend(r, symbol);
291        }
292        // Update DC prediction
293        *dc_prediction = dc_prediction.wrapping_add(symbol);
294
295        return Ok(true);
296    }
297
298    /// Decode a Minimum Code Unit(MCU) as quickly as possible
299    ///
300    /// # Arguments
301    /// - reader: The bitstream from where we read more bits.
302    /// - dc_table: The Huffman table used to decode the DC coefficient
303    /// - ac_table: The Huffman table used to decode AC values
304    /// - block: A memory region where we will write out the decoded values
305    /// - DC prediction: Last DC value for this component
306    ///
307    #[allow(
308        clippy::many_single_char_names,
309        clippy::cast_possible_truncation,
310        clippy::cast_sign_loss
311    )]
312    #[inline(never)]
313    pub fn decode_mcu_block<T>(
314        &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, ac_table: &HuffmanTable,
315        qt_table: &[i32; DCT_BLOCK], block: &mut [i32; 64], dc_prediction: &mut i32
316    ) -> Result<(), DecodeErrors>
317    where
318        T: ZReaderTrait
319    {
320        // Get fast AC table as a reference before we enter the hot path
321        let ac_lookup = ac_table.ac_lookup.as_ref().unwrap();
322
323        let (mut symbol, mut r, mut fast_ac);
324        // Decode AC coefficients
325        let mut pos: usize = 1;
326
327        // decode DC, dc prediction will contain the value
328        self.decode_dc(reader, dc_table, dc_prediction)?;
329
330        // set dc to be the dc prediction.
331        block[0] = *dc_prediction * qt_table[0];
332
333        while pos < 64 {
334            self.refill(reader)?;
335            symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
336            fast_ac = ac_lookup[symbol as usize];
337            symbol = ac_table.lookup[symbol as usize];
338
339            if fast_ac != 0 {
340                //  FAST AC path
341                pos += ((fast_ac >> 4) & 15) as usize; // run
342                let t_pos = UN_ZIGZAG[min(pos, 63)] & 63;
343
344                block[t_pos] = i32::from(fast_ac >> 8) * (qt_table[t_pos]); // Value
345                self.drop_bits((fast_ac & 15) as u8);
346                pos += 1;
347            } else {
348                decode_huff!(self, symbol, ac_table);
349
350                r = symbol >> 4;
351                symbol &= 15;
352
353                if symbol != 0 {
354                    pos += r as usize;
355                    r = self.get_bits(symbol as u8);
356                    symbol = huff_extend(r, symbol);
357                    let t_pos = UN_ZIGZAG[pos & 63] & 63;
358
359                    block[t_pos] = symbol * qt_table[t_pos];
360
361                    pos += 1;
362                } else if r != 15 {
363                    return Ok(());
364                } else {
365                    pos += 16;
366                }
367            }
368        }
369        return Ok(());
370    }
371
372    /// Peek `look_ahead` bits ahead without discarding them from the buffer
373    #[inline(always)]
374    #[allow(clippy::cast_possible_truncation)]
375    const fn peek_bits<const LOOKAHEAD: u8>(&self) -> i32 {
376        (self.aligned_buffer >> (64 - LOOKAHEAD)) as i32
377    }
378
379    /// Discard the next `N` bits without checking
380    #[inline]
381    fn drop_bits(&mut self, n: u8) {
382        self.bits_left = self.bits_left.saturating_sub(n);
383        self.aligned_buffer <<= n;
384    }
385
386    /// Read `n_bits` from the buffer  and discard them
387    #[inline(always)]
388    #[allow(clippy::cast_possible_truncation)]
389    fn get_bits(&mut self, n_bits: u8) -> i32 {
390        let mask = (1_u64 << n_bits) - 1;
391
392        self.aligned_buffer = self.aligned_buffer.rotate_left(u32::from(n_bits));
393        let bits = (self.aligned_buffer & mask) as i32;
394        self.bits_left = self.bits_left.wrapping_sub(n_bits);
395        bits
396    }
397
398    /// Decode a DC block
399    #[allow(clippy::cast_possible_truncation)]
400    #[inline]
401    pub(crate) fn decode_prog_dc_first<T>(
402        &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, block: &mut i16,
403        dc_prediction: &mut i32
404    ) -> Result<(), DecodeErrors>
405    where
406        T: ZReaderTrait
407    {
408        self.decode_dc(reader, dc_table, dc_prediction)?;
409        *block = (*dc_prediction as i16).wrapping_mul(1_i16 << self.successive_low);
410        return Ok(());
411    }
412    #[inline]
413    pub(crate) fn decode_prog_dc_refine<T>(
414        &mut self, reader: &mut ZByteReader<T>, block: &mut i16
415    ) -> Result<(), DecodeErrors>
416    where
417        T: ZReaderTrait
418    {
419        // refinement scan
420        if self.bits_left < 1 {
421            self.refill(reader)?;
422        }
423
424        if self.get_bit() == 1 {
425            *block = block.wrapping_add(1 << self.successive_low);
426        }
427
428        Ok(())
429    }
430
431    /// Get a single bit from the bitstream
432    fn get_bit(&mut self) -> u8 {
433        let k = (self.aligned_buffer >> 63) as u8;
434        // discard a bit
435        self.drop_bits(1);
436        return k;
437    }
438    pub(crate) fn decode_mcu_ac_first<T>(
439        &mut self, reader: &mut ZByteReader<T>, ac_table: &HuffmanTable, block: &mut [i16; 64]
440    ) -> Result<bool, DecodeErrors>
441    where
442        T: ZReaderTrait
443    {
444        let shift = self.successive_low;
445        let fast_ac = ac_table.ac_lookup.as_ref().unwrap();
446
447        let mut k = self.spec_start as usize;
448        let (mut symbol, mut r, mut fac);
449
450        // EOB runs are handled in mcu_prog.rs
451        'block: loop {
452            self.refill(reader)?;
453
454            symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
455            fac = fast_ac[symbol as usize];
456            symbol = ac_table.lookup[symbol as usize];
457
458            if fac != 0 {
459                // fast ac path
460                k += ((fac >> 4) & 15) as usize; // run
461                block[UN_ZIGZAG[min(k, 63)] & 63] = (fac >> 8).wrapping_mul(1 << shift); // value
462                self.drop_bits((fac & 15) as u8);
463                k += 1;
464            } else {
465                decode_huff!(self, symbol, ac_table);
466
467                r = symbol >> 4;
468                symbol &= 15;
469
470                if symbol != 0 {
471                    k += r as usize;
472                    r = self.get_bits(symbol as u8);
473                    symbol = huff_extend(r, symbol);
474                    block[UN_ZIGZAG[k & 63] & 63] = (symbol as i16).wrapping_mul(1 << shift);
475                    k += 1;
476                } else {
477                    if r != 15 {
478                        self.eob_run = 1 << r;
479                        self.eob_run += self.get_bits(r as u8);
480                        self.eob_run -= 1;
481                        break;
482                    }
483
484                    k += 16;
485                }
486            }
487
488            if k > self.spec_end as usize {
489                break 'block;
490            }
491        }
492        return Ok(true);
493    }
494    #[allow(clippy::too_many_lines, clippy::op_ref)]
495    pub(crate) fn decode_mcu_ac_refine<T>(
496        &mut self, reader: &mut ZByteReader<T>, table: &HuffmanTable, block: &mut [i16; 64]
497    ) -> Result<bool, DecodeErrors>
498    where
499        T: ZReaderTrait
500    {
501        let bit = (1 << self.successive_low) as i16;
502
503        let mut k = self.spec_start;
504        let (mut symbol, mut r);
505
506        if self.eob_run == 0 {
507            'no_eob: loop {
508                // Decode a coefficient from the bit stream
509                self.refill(reader)?;
510
511                symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
512                symbol = table.lookup[symbol as usize];
513
514                decode_huff!(self, symbol, table);
515
516                r = symbol >> 4;
517                symbol &= 15;
518
519                if symbol == 0 {
520                    if r != 15 {
521                        // EOB run is 2^r + bits
522                        self.eob_run = 1 << r;
523                        self.eob_run += self.get_bits(r as u8);
524                        // EOB runs are handled by the eob logic
525                        break 'no_eob;
526                    }
527                } else {
528                    if symbol != 1 {
529                        return Err(DecodeErrors::HuffmanDecode(
530                            "Bad Huffman code, corrupt JPEG?".to_string()
531                        ));
532                    }
533                    // get sign bit
534                    // We assume we have enough bits, which should be correct for sane images
535                    // since we refill by 32 above
536                    if self.get_bit() == 1 {
537                        symbol = i32::from(bit);
538                    } else {
539                        symbol = i32::from(-bit);
540                    }
541                }
542
543                // Advance over already nonzero coefficients  appending
544                // correction bits to the non-zeroes.
545                // A correction bit is 1 if the absolute value of the coefficient must be increased
546
547                if k <= self.spec_end {
548                    'advance_nonzero: loop {
549                        let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
550
551                        if *coefficient != 0 {
552                            if self.get_bit() == 1 && (*coefficient & bit) == 0 {
553                                if *coefficient >= 0 {
554                                    *coefficient += bit;
555                                } else {
556                                    *coefficient -= bit;
557                                }
558                            }
559
560                            if self.bits_left < 1 {
561                                self.refill(reader)?;
562                            }
563                        } else {
564                            r -= 1;
565
566                            if r < 0 {
567                                // reached target zero coefficient.
568                                break 'advance_nonzero;
569                            }
570                        };
571
572                        if k == self.spec_end {
573                            break 'advance_nonzero;
574                        }
575
576                        k += 1;
577                    }
578                }
579
580                if symbol != 0 {
581                    let pos = UN_ZIGZAG[k as usize & 63];
582                    // output new non-zero coefficient.
583                    block[pos & 63] = symbol as i16;
584                }
585
586                k += 1;
587
588                if k > self.spec_end {
589                    break 'no_eob;
590                }
591            }
592        }
593        if self.eob_run > 0 {
594            // only run if block does not consists of purely zeroes
595            if &block[1..] != &[0; 63] {
596                self.refill(reader)?;
597
598                while k <= self.spec_end {
599                    let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
600
601                    if *coefficient != 0 && self.get_bit() == 1 {
602                        // check if we already modified it, if so do nothing, otherwise
603                        // append the correction bit.
604                        if (*coefficient & bit) == 0 {
605                            if *coefficient >= 0 {
606                                *coefficient = coefficient.wrapping_add(bit);
607                            } else {
608                                *coefficient = coefficient.wrapping_sub(bit);
609                            }
610                        }
611                    }
612                    if self.bits_left < 1 {
613                        // refill at the last possible moment
614                        self.refill(reader)?;
615                    }
616                    k += 1;
617                }
618            }
619            // count a block completed in EOB run
620            self.eob_run -= 1;
621        }
622        return Ok(true);
623    }
624
625    pub fn update_progressive_params(&mut self, ah: u8, al: u8, spec_start: u8, spec_end: u8) {
626        self.successive_high = ah;
627        self.successive_low = al;
628        self.spec_start = spec_start;
629        self.spec_end = spec_end;
630    }
631
632    /// Reset the stream if we have a restart marker
633    ///
634    /// Restart markers indicate drop those bits in the stream and zero out
635    /// everything
636    #[cold]
637    pub fn reset(&mut self) {
638        self.bits_left = 0;
639        self.marker = None;
640        self.buffer = 0;
641        self.aligned_buffer = 0;
642        self.eob_run = 0;
643    }
644}
645
/// Sign-extend an `s`-bit value read from the stream (JPEG `HUFF_EXTEND`).
///
/// Values whose top bit is set (`x >= 2^(s-1)`) are returned unchanged,
/// smaller values are mapped to the negative range by adding `(-1 << s) + 1`.
/// The selection is done without branching.
///
/// `s` must be at least 1.
#[inline(always)]
fn huff_extend(x: i32, s: i32) -> i32 {
    // smallest non-negative value representable with `s` bits
    let threshold = 1 << (s - 1);
    // all ones when `x` lies below the threshold, zero otherwise
    let below_mask = (x - threshold) >> 31;
    // amount to add to turn `x` into its negative counterpart
    let negative_offset = ((-1_i32) << s) + 1;

    x + (below_mask & negative_offset)
}
652
/// Report whether any of the four bytes of `v` is zero.
///
/// Bit trick from the Stanford bithacks collection
/// (https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord).
fn has_zero(v: u32) -> bool {
    const LOW_SEVEN: u32 = 0x7F7F_7F7F;

    // After this, the high bit of each byte is set iff that byte of `v` was
    // non-zero; the low seven bits of every byte are forced to one.
    let nonzero_flags = ((v & LOW_SEVEN) + LOW_SEVEN) | v | LOW_SEVEN;

    // a cleared high bit anywhere means that byte was zero
    nonzero_flags != u32::MAX
}
658
659fn has_byte(b: u32, val: u8) -> bool {
660    // Retrieved from Stanford bithacks
661    // @ https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
662    has_zero(b ^ ((!0_u32 / 255) * u32::from(val)))
663}