zune_jpeg/bitstream.rs
1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9#![allow(
10 clippy::if_not_else,
11 clippy::similar_names,
12 clippy::inline_always,
13 clippy::doc_markdown,
14 clippy::cast_sign_loss,
15 clippy::cast_possible_truncation
16)]
17
18//! This file exposes a single struct that can decode a huffman encoded
19//! Bitstream in a JPEG file
20//!
21//! This code is optimized for speed.
22//! It's meant to be super duper super fast, because everyone else depends on this being fast.
//! It's (annoyingly) serial, hence we can't use parallel bitstreams (it's variable-length coding).
24//!
25//! Furthermore, on the case of refills, we have to do bytewise processing because the standard decided
26//! that we want to support markers in the middle of streams(seriously few people use RST markers).
27//!
28//! So we pull in all optimization steps:
29//! - use `inline[always]`? ✅ ,
30//! - pre-execute most common cases ✅,
31//! - add random comments ✅
32//! - fast paths ✅.
33//!
34//! Speed-wise: It is probably the fastest JPEG BitStream decoder to ever sail the seven seas because of
35//! a couple of optimization tricks.
36//! 1. Fast refills from libjpeg-turbo
37//! 2. As few as possible branches in decoder fast paths.
//! 3. Accelerated AC table decoding borrowed from stb_image.h, written by Fabian Giesen (@rygorous),
39//! improved by me to handle more cases.
40//! 4. Safe and extensible routines(e.g. cool ways to eliminate bounds check)
41//! 5. No unsafe here
42//!
43//! Readability comes as a second priority(I tried with variable names this time, and we are wayy better than libjpeg).
44//!
//! Anyway, if you are reading this it means you're cool and I hope you find whatever part of the code you are looking for
46//! (or learn something cool)
47//!
48//! Knock yourself out.
49use alloc::format;
50use alloc::string::ToString;
51use core::cmp::min;
52
53use zune_core::bytestream::{ZByteReader, ZReaderTrait};
54
55use crate::errors::DecodeErrors;
56use crate::huffman::{HuffmanTable, HUFF_LOOKAHEAD};
57use crate::marker::Marker;
58use crate::mcu::DCT_BLOCK;
59use crate::misc::UN_ZIGZAG;
60
/// Resolve a Huffman lookup entry into a decoded symbol and consume its code bits.
///
/// Arguments (all single tokens):
/// - `$stream`: the bit reader; must provide `peek_bits::<N>()` and `drop_bits(u8)`.
/// - `$symbol`: a mutable `i32` holding the lookup-table entry on entry and the
///   decoded symbol on exit.
/// - `$table`: the Huffman table; its `maxcode`, `offset` and `values` arrays are read.
///
/// On an undecodable code the macro `return`s `Err(DecodeErrors::Format(..))` from
/// the *enclosing* function, so it may only be used where that return type fits.
macro_rules! decode_huff {
    ($stream:tt,$symbol:tt,$table:tt) => {
        // A lookup entry packs the code length above the low HUFF_LOOKAHEAD bits
        // and the symbol value inside them.
        let mut code_length = $symbol >> HUFF_LOOKAHEAD;
        $symbol &= (1 << HUFF_LOOKAHEAD) - 1;

        if code_length > i32::from(HUFF_LOOKAHEAD) {
            // The code is longer than the lookahead window. JPEG caps codes at
            // 16 bits, so peek 16 bits and search upwards for the matching length.
            $symbol = ($stream).peek_bits::<16>() as i32;

            // (Credits to Sean T. Barrett's stb library for this optimization.)
            // Per the original note, `maxcode` entries are pre-shifted to 16 bits
            // wide, so the comparison needs no shift inside the loop.
            while code_length < 17 && $symbol >= $table.maxcode[code_length as usize] {
                code_length += 1;
            }

            if code_length == 17 {
                // No length matched. Huffman streams cannot resynchronise, so
                // everything after this point would be garbage: bail out.
                return Err(DecodeErrors::Format(format!(
                    "Bad Huffman Code 0x{:X}, corrupt JPEG",
                    $symbol
                )));
            }

            // Keep only the `code_length` leading bits and map them to a value.
            let value_index =
                (($symbol >> (16 - code_length)) + $table.offset[code_length as usize]) & 0xFF;
            $symbol = i32::from($table.values[value_index as usize]);
        }
        // Consume the bits that made up the code.
        ($stream).drop_bits(code_length as u8);
    };
}
103
104/// A `BitStream` struct, a bit by bit reader with super powers
105///
106pub(crate) struct BitStream {
107 /// A MSB type buffer that is used for some certain operations
108 pub buffer: u64,
109 /// A TOP aligned MSB type buffer that is used to accelerate some operations like
110 /// peek_bits and get_bits.
111 ///
112 /// By top aligned, I mean the top bit (63) represents the top bit in the buffer.
113 aligned_buffer: u64,
114 /// Tell us the bits left the two buffer
115 pub(crate) bits_left: u8,
116 /// Did we find a marker(RST/EOF) during decoding?
117 pub marker: Option<Marker>,
118
119 /// Progressive decoding
120 pub successive_high: u8,
121 pub successive_low: u8,
122 spec_start: u8,
123 spec_end: u8,
124 pub eob_run: i32,
125 pub overread_by: usize
126}
127
128impl BitStream {
129 /// Create a new BitStream
130 pub(crate) const fn new() -> BitStream {
131 BitStream {
132 buffer: 0,
133 aligned_buffer: 0,
134 bits_left: 0,
135 marker: None,
136 successive_high: 0,
137 successive_low: 0,
138 spec_start: 0,
139 spec_end: 0,
140 eob_run: 0,
141 overread_by: 0
142 }
143 }
144
145 /// Create a new Bitstream for progressive decoding
146 #[allow(clippy::redundant_field_names)]
147 pub(crate) fn new_progressive(ah: u8, al: u8, spec_start: u8, spec_end: u8) -> BitStream {
148 BitStream {
149 buffer: 0,
150 aligned_buffer: 0,
151 bits_left: 0,
152 marker: None,
153 successive_high: ah,
154 successive_low: al,
155 spec_start: spec_start,
156 spec_end: spec_end,
157 eob_run: 0,
158 overread_by: 0
159 }
160 }
161
162 /// Refill the bit buffer by (a maximum of) 32 bits
163 ///
164 /// # Arguments
165 /// - `reader`:`&mut BufReader<R>`: A mutable reference to an underlying
166 /// File/Memory buffer containing a valid JPEG stream
167 ///
168 /// This function will only refill if `self.count` is less than 32
169 #[inline(always)] // to many call sites? ( perf improvement by 4%)
170 fn refill<T>(&mut self, reader: &mut ZByteReader<T>) -> Result<bool, DecodeErrors>
171 where
172 T: ZReaderTrait
173 {
174 /// Macro version of a single byte refill.
175 /// Arguments
176 /// buffer-> our io buffer, because rust macros cannot get values from
177 /// the surrounding environment bits_left-> number of bits left
178 /// to full refill
179 macro_rules! refill {
180 ($buffer:expr,$byte:expr,$bits_left:expr) => {
181 // read a byte from the stream
182 $byte = u64::from(reader.get_u8());
183 self.overread_by += usize::from(reader.eof());
184 // append to the buffer
185 // JPEG is a MSB type buffer so that means we append this
186 // to the lower end (0..8) of the buffer and push the rest bits above..
187 $buffer = ($buffer << 8) | $byte;
188 // Increment bits left
189 $bits_left += 8;
190 // Check for special case of OxFF, to see if it's a stream or a marker
191 if $byte == 0xff {
192 // read next byte
193 let mut next_byte = u64::from(reader.get_u8());
194 // Byte snuffing, if we encounter byte snuff, we skip the byte
195 if next_byte != 0x00 {
196 // skip that byte we read
197 while next_byte == 0xFF {
198 next_byte = u64::from(reader.get_u8());
199 }
200
201 if next_byte != 0x00 {
202 // Undo the byte append and return
203 $buffer >>= 8;
204 $bits_left -= 8;
205
206 if $bits_left != 0 {
207 self.aligned_buffer = $buffer << (64 - $bits_left);
208 }
209
210 self.marker =
211 Some(Marker::from_u8(next_byte as u8).ok_or_else(|| {
212 DecodeErrors::Format(format!(
213 "Unknown marker 0xFF{:X}",
214 next_byte
215 ))
216 })?);
217 return Ok(false);
218 }
219 }
220 }
221 };
222 }
223
224 // 32 bits is enough for a decode(16 bits) and receive_extend(max 16 bits)
225 // If we have less than 32 bits we refill
226 if self.bits_left < 32 && self.marker.is_none() {
227 // So before we do anything, check if we have a 0xFF byte
228
229 if reader.has(4) {
230 // we have 4 bytes to spare, read the 4 bytes into a temporary buffer
231 // create buffer
232 let msb_buf = reader.get_u32_be();
233 // check if we have 0xff
234 if !has_byte(msb_buf, 255) {
235 self.bits_left += 32;
236 self.buffer <<= 32;
237 self.buffer |= u64::from(msb_buf);
238 self.aligned_buffer = self.buffer << (64 - self.bits_left);
239 return Ok(true);
240 }
241 // not there, rewind the read
242 reader.rewind(4);
243 }
244 // This serves two reasons,
245 // 1: Make clippy shut up
246 // 2: Favour register reuse
247 let mut byte;
248
249 // 4 refills, if all succeed the stream should contain enough bits to decode a
250 // value
251 refill!(self.buffer, byte, self.bits_left);
252 refill!(self.buffer, byte, self.bits_left);
253 refill!(self.buffer, byte, self.bits_left);
254 refill!(self.buffer, byte, self.bits_left);
255 // Construct an MSB buffer whose top bits are the bitstream we are currently holding.
256 self.aligned_buffer = self.buffer << (64 - self.bits_left);
257 }
258
259 return Ok(true);
260 }
261 /// Decode the DC coefficient in a MCU block.
262 ///
263 /// The decoded coefficient is written to `dc_prediction`
264 ///
265 #[allow(
266 clippy::cast_possible_truncation,
267 clippy::cast_sign_loss,
268 clippy::unwrap_used
269 )]
270 #[inline(always)]
271 fn decode_dc<T>(
272 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, dc_prediction: &mut i32
273 ) -> Result<bool, DecodeErrors>
274 where
275 T: ZReaderTrait
276 {
277 let (mut symbol, r);
278
279 if self.bits_left < 32 {
280 self.refill(reader)?;
281 };
282 // look a head HUFF_LOOKAHEAD bits into the bitstream
283 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
284 symbol = dc_table.lookup[symbol as usize];
285
286 decode_huff!(self, symbol, dc_table);
287
288 if symbol != 0 {
289 r = self.get_bits(symbol as u8);
290 symbol = huff_extend(r, symbol);
291 }
292 // Update DC prediction
293 *dc_prediction = dc_prediction.wrapping_add(symbol);
294
295 return Ok(true);
296 }
297
298 /// Decode a Minimum Code Unit(MCU) as quickly as possible
299 ///
300 /// # Arguments
301 /// - reader: The bitstream from where we read more bits.
302 /// - dc_table: The Huffman table used to decode the DC coefficient
303 /// - ac_table: The Huffman table used to decode AC values
304 /// - block: A memory region where we will write out the decoded values
305 /// - DC prediction: Last DC value for this component
306 ///
307 #[allow(
308 clippy::many_single_char_names,
309 clippy::cast_possible_truncation,
310 clippy::cast_sign_loss
311 )]
312 #[inline(never)]
313 pub fn decode_mcu_block<T>(
314 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, ac_table: &HuffmanTable,
315 qt_table: &[i32; DCT_BLOCK], block: &mut [i32; 64], dc_prediction: &mut i32
316 ) -> Result<(), DecodeErrors>
317 where
318 T: ZReaderTrait
319 {
320 // Get fast AC table as a reference before we enter the hot path
321 let ac_lookup = ac_table.ac_lookup.as_ref().unwrap();
322
323 let (mut symbol, mut r, mut fast_ac);
324 // Decode AC coefficients
325 let mut pos: usize = 1;
326
327 // decode DC, dc prediction will contain the value
328 self.decode_dc(reader, dc_table, dc_prediction)?;
329
330 // set dc to be the dc prediction.
331 block[0] = *dc_prediction * qt_table[0];
332
333 while pos < 64 {
334 self.refill(reader)?;
335 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
336 fast_ac = ac_lookup[symbol as usize];
337 symbol = ac_table.lookup[symbol as usize];
338
339 if fast_ac != 0 {
340 // FAST AC path
341 pos += ((fast_ac >> 4) & 15) as usize; // run
342 let t_pos = UN_ZIGZAG[min(pos, 63)] & 63;
343
344 block[t_pos] = i32::from(fast_ac >> 8) * (qt_table[t_pos]); // Value
345 self.drop_bits((fast_ac & 15) as u8);
346 pos += 1;
347 } else {
348 decode_huff!(self, symbol, ac_table);
349
350 r = symbol >> 4;
351 symbol &= 15;
352
353 if symbol != 0 {
354 pos += r as usize;
355 r = self.get_bits(symbol as u8);
356 symbol = huff_extend(r, symbol);
357 let t_pos = UN_ZIGZAG[pos & 63] & 63;
358
359 block[t_pos] = symbol * qt_table[t_pos];
360
361 pos += 1;
362 } else if r != 15 {
363 return Ok(());
364 } else {
365 pos += 16;
366 }
367 }
368 }
369 return Ok(());
370 }
371
372 /// Peek `look_ahead` bits ahead without discarding them from the buffer
373 #[inline(always)]
374 #[allow(clippy::cast_possible_truncation)]
375 const fn peek_bits<const LOOKAHEAD: u8>(&self) -> i32 {
376 (self.aligned_buffer >> (64 - LOOKAHEAD)) as i32
377 }
378
379 /// Discard the next `N` bits without checking
380 #[inline]
381 fn drop_bits(&mut self, n: u8) {
382 self.bits_left = self.bits_left.saturating_sub(n);
383 self.aligned_buffer <<= n;
384 }
385
386 /// Read `n_bits` from the buffer and discard them
387 #[inline(always)]
388 #[allow(clippy::cast_possible_truncation)]
389 fn get_bits(&mut self, n_bits: u8) -> i32 {
390 let mask = (1_u64 << n_bits) - 1;
391
392 self.aligned_buffer = self.aligned_buffer.rotate_left(u32::from(n_bits));
393 let bits = (self.aligned_buffer & mask) as i32;
394 self.bits_left = self.bits_left.wrapping_sub(n_bits);
395 bits
396 }
397
398 /// Decode a DC block
399 #[allow(clippy::cast_possible_truncation)]
400 #[inline]
401 pub(crate) fn decode_prog_dc_first<T>(
402 &mut self, reader: &mut ZByteReader<T>, dc_table: &HuffmanTable, block: &mut i16,
403 dc_prediction: &mut i32
404 ) -> Result<(), DecodeErrors>
405 where
406 T: ZReaderTrait
407 {
408 self.decode_dc(reader, dc_table, dc_prediction)?;
409 *block = (*dc_prediction as i16).wrapping_mul(1_i16 << self.successive_low);
410 return Ok(());
411 }
412 #[inline]
413 pub(crate) fn decode_prog_dc_refine<T>(
414 &mut self, reader: &mut ZByteReader<T>, block: &mut i16
415 ) -> Result<(), DecodeErrors>
416 where
417 T: ZReaderTrait
418 {
419 // refinement scan
420 if self.bits_left < 1 {
421 self.refill(reader)?;
422 }
423
424 if self.get_bit() == 1 {
425 *block = block.wrapping_add(1 << self.successive_low);
426 }
427
428 Ok(())
429 }
430
431 /// Get a single bit from the bitstream
432 fn get_bit(&mut self) -> u8 {
433 let k = (self.aligned_buffer >> 63) as u8;
434 // discard a bit
435 self.drop_bits(1);
436 return k;
437 }
438 pub(crate) fn decode_mcu_ac_first<T>(
439 &mut self, reader: &mut ZByteReader<T>, ac_table: &HuffmanTable, block: &mut [i16; 64]
440 ) -> Result<bool, DecodeErrors>
441 where
442 T: ZReaderTrait
443 {
444 let shift = self.successive_low;
445 let fast_ac = ac_table.ac_lookup.as_ref().unwrap();
446
447 let mut k = self.spec_start as usize;
448 let (mut symbol, mut r, mut fac);
449
450 // EOB runs are handled in mcu_prog.rs
451 'block: loop {
452 self.refill(reader)?;
453
454 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
455 fac = fast_ac[symbol as usize];
456 symbol = ac_table.lookup[symbol as usize];
457
458 if fac != 0 {
459 // fast ac path
460 k += ((fac >> 4) & 15) as usize; // run
461 block[UN_ZIGZAG[min(k, 63)] & 63] = (fac >> 8).wrapping_mul(1 << shift); // value
462 self.drop_bits((fac & 15) as u8);
463 k += 1;
464 } else {
465 decode_huff!(self, symbol, ac_table);
466
467 r = symbol >> 4;
468 symbol &= 15;
469
470 if symbol != 0 {
471 k += r as usize;
472 r = self.get_bits(symbol as u8);
473 symbol = huff_extend(r, symbol);
474 block[UN_ZIGZAG[k & 63] & 63] = (symbol as i16).wrapping_mul(1 << shift);
475 k += 1;
476 } else {
477 if r != 15 {
478 self.eob_run = 1 << r;
479 self.eob_run += self.get_bits(r as u8);
480 self.eob_run -= 1;
481 break;
482 }
483
484 k += 16;
485 }
486 }
487
488 if k > self.spec_end as usize {
489 break 'block;
490 }
491 }
492 return Ok(true);
493 }
494 #[allow(clippy::too_many_lines, clippy::op_ref)]
495 pub(crate) fn decode_mcu_ac_refine<T>(
496 &mut self, reader: &mut ZByteReader<T>, table: &HuffmanTable, block: &mut [i16; 64]
497 ) -> Result<bool, DecodeErrors>
498 where
499 T: ZReaderTrait
500 {
501 let bit = (1 << self.successive_low) as i16;
502
503 let mut k = self.spec_start;
504 let (mut symbol, mut r);
505
506 if self.eob_run == 0 {
507 'no_eob: loop {
508 // Decode a coefficient from the bit stream
509 self.refill(reader)?;
510
511 symbol = self.peek_bits::<HUFF_LOOKAHEAD>();
512 symbol = table.lookup[symbol as usize];
513
514 decode_huff!(self, symbol, table);
515
516 r = symbol >> 4;
517 symbol &= 15;
518
519 if symbol == 0 {
520 if r != 15 {
521 // EOB run is 2^r + bits
522 self.eob_run = 1 << r;
523 self.eob_run += self.get_bits(r as u8);
524 // EOB runs are handled by the eob logic
525 break 'no_eob;
526 }
527 } else {
528 if symbol != 1 {
529 return Err(DecodeErrors::HuffmanDecode(
530 "Bad Huffman code, corrupt JPEG?".to_string()
531 ));
532 }
533 // get sign bit
534 // We assume we have enough bits, which should be correct for sane images
535 // since we refill by 32 above
536 if self.get_bit() == 1 {
537 symbol = i32::from(bit);
538 } else {
539 symbol = i32::from(-bit);
540 }
541 }
542
543 // Advance over already nonzero coefficients appending
544 // correction bits to the non-zeroes.
545 // A correction bit is 1 if the absolute value of the coefficient must be increased
546
547 if k <= self.spec_end {
548 'advance_nonzero: loop {
549 let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
550
551 if *coefficient != 0 {
552 if self.get_bit() == 1 && (*coefficient & bit) == 0 {
553 if *coefficient >= 0 {
554 *coefficient += bit;
555 } else {
556 *coefficient -= bit;
557 }
558 }
559
560 if self.bits_left < 1 {
561 self.refill(reader)?;
562 }
563 } else {
564 r -= 1;
565
566 if r < 0 {
567 // reached target zero coefficient.
568 break 'advance_nonzero;
569 }
570 };
571
572 if k == self.spec_end {
573 break 'advance_nonzero;
574 }
575
576 k += 1;
577 }
578 }
579
580 if symbol != 0 {
581 let pos = UN_ZIGZAG[k as usize & 63];
582 // output new non-zero coefficient.
583 block[pos & 63] = symbol as i16;
584 }
585
586 k += 1;
587
588 if k > self.spec_end {
589 break 'no_eob;
590 }
591 }
592 }
593 if self.eob_run > 0 {
594 // only run if block does not consists of purely zeroes
595 if &block[1..] != &[0; 63] {
596 self.refill(reader)?;
597
598 while k <= self.spec_end {
599 let coefficient = &mut block[UN_ZIGZAG[k as usize & 63] & 63];
600
601 if *coefficient != 0 && self.get_bit() == 1 {
602 // check if we already modified it, if so do nothing, otherwise
603 // append the correction bit.
604 if (*coefficient & bit) == 0 {
605 if *coefficient >= 0 {
606 *coefficient = coefficient.wrapping_add(bit);
607 } else {
608 *coefficient = coefficient.wrapping_sub(bit);
609 }
610 }
611 }
612 if self.bits_left < 1 {
613 // refill at the last possible moment
614 self.refill(reader)?;
615 }
616 k += 1;
617 }
618 }
619 // count a block completed in EOB run
620 self.eob_run -= 1;
621 }
622 return Ok(true);
623 }
624
625 pub fn update_progressive_params(&mut self, ah: u8, al: u8, spec_start: u8, spec_end: u8) {
626 self.successive_high = ah;
627 self.successive_low = al;
628 self.spec_start = spec_start;
629 self.spec_end = spec_end;
630 }
631
632 /// Reset the stream if we have a restart marker
633 ///
634 /// Restart markers indicate drop those bits in the stream and zero out
635 /// everything
636 #[cold]
637 pub fn reset(&mut self) {
638 self.bits_left = 0;
639 self.marker = None;
640 self.buffer = 0;
641 self.aligned_buffer = 0;
642 self.eob_run = 0;
643 }
644}
645
/// Branch-free equivalent of the JPEG `HUFF_EXTEND` procedure.
///
/// Interprets `x`, an `s`-bit value, as a signed quantity: when `x` is below
/// `2^(s-1)` the result is `x + ((-1 << s) + 1)`, otherwise it is `x` itself.
/// `s` must be at least 1.
#[inline(always)]
fn huff_extend(x: i32, s: i32) -> i32 {
    // Smallest value that is already non-negative.
    let half_range = 1 << (s - 1);
    // All ones when `x < half_range`, all zeroes otherwise (arithmetic shift).
    let negative_mask = (x - half_range) >> 31;
    // Amount to add to map the lower half onto the negative range.
    let negative_offset = ((-1) << s) + 1;

    x + (negative_mask & negative_offset)
}
652
/// Report whether any of the four bytes of `v` is zero.
///
/// Bit trick from the Stanford bithacks collection:
/// https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
fn has_zero(v: u32) -> bool {
    const LOW_SEVEN: u32 = 0x7F7F_7F7F;

    // Adding 0x7F to the low seven bits of a byte carries into its top bit
    // unless those seven bits are all zero.
    let carried = (v & LOW_SEVEN) + LOW_SEVEN;
    // A byte's top bit remains clear here only if the whole byte was zero.
    let zero_flags = !(carried | v | LOW_SEVEN);

    zero_flags != 0
}
658
659fn has_byte(b: u32, val: u8) -> bool {
660 // Retrieved from Stanford bithacks
661 // @ https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
662 has_zero(b ^ ((!0_u32 / 255) * u32::from(val)))
663}