zune_jpeg/mcu_prog.rs
/*
 * Copyright (c) 2023.
 *
 * This software is free software;
 *
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
 */

//! Routines for progressive decoding
/*
This file is needlessly complicated.

It is that way to ensure we don't burn memory unnecessarily.

Memory is a scarce resource in some environments, and I would like this to be viable
in such environments.

Half of the complexity comes from the jpeg spec, because progressive decoding
is one hell of a ride.
*/
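
// High-level flow of this module:
//  - `decode_mcu_ycbcr_progressive` drives the scan loop: it allocates one coefficient
//    buffer per component, decodes the first scan and then keeps handling DHT/SOS
//    markers and their scans until EOI (or until the scan limit is hit).
//  - `parse_entropy_coded_data` decodes a single scan, either one component's band of
//    coefficients or the interleaved DC coefficients of several components.
//  - `finish_progressive_decoding` dequantizes, applies the IDCT and color converts
//    the accumulated coefficients one MCU row at a time.
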
use alloc::string::ToString;
use alloc::vec::Vec;
use alloc::{format, vec};
use core::cmp::min;

use zune_core::bytestream::{ZByteReader, ZReaderTrait};
use zune_core::colorspace::ColorSpace;
use zune_core::log::{debug, error, warn};

use crate::bitstream::BitStream;
use crate::components::{ComponentID, SampleRatios};
use crate::decoder::{JpegDecoder, MAX_COMPONENTS};
use crate::errors::DecodeErrors;
use crate::errors::DecodeErrors::Format;
use crate::headers::{parse_huffman, parse_sos};
use crate::marker::Marker;
use crate::mcu::DCT_BLOCK;
use crate::misc::{calculate_padded_width, setup_component_params};

impl<T: ZReaderTrait> JpegDecoder<T> {
    /// Decode a progressive image
    ///
    /// This routine decodes a progressive image, stopping if it finds any error.
    #[allow(
        clippy::needless_range_loop,
        clippy::cast_sign_loss,
        clippy::redundant_else,
        clippy::too_many_lines
    )]
    #[inline(never)]
    pub(crate) fn decode_mcu_ycbcr_progressive(
        &mut self, pixels: &mut [u8]
    ) -> Result<(), DecodeErrors> {
        setup_component_params(self)?;

        let mut mcu_height;

        // memory location for decoded pixels for components
        let mut block: [Vec<i16>; MAX_COMPONENTS] = [vec![], vec![], vec![], vec![]];
        let mut mcu_width;

        let mut seen_scans = 1;

        if self.input_colorspace == ColorSpace::Luma && self.is_interleaved {
            warn!("Grayscale image with down-sampled component, resetting component details");
            self.reset_params();
        }

        if self.is_interleaved {
            // this helps us catch component errors.
            self.set_upsampling()?;
        }
        if self.is_interleaved {
            mcu_width = self.mcu_x;
            mcu_height = self.mcu_y;
        } else {
            mcu_width = (self.info.width as usize + 7) / 8;
            mcu_height = (self.info.height as usize + 7) / 8;
        }
        if self.is_interleaved
            && self.input_colorspace.num_components() > 1
            && self.options.jpeg_get_out_colorspace().num_components() == 1
            && (self.sub_sample_ratio == SampleRatios::V
                || self.sub_sample_ratio == SampleRatios::HV)
        {
            // For a specific set of images, e.g. interleaved ones,
            // when converting from YCbCr to grayscale, we need to
            // take the MCU height into account, since the MCU decoding needs it
            // for padding purposes and the post processor
            // parses two rows per MCU width.
            //
            // Set coeff to 2 to ensure that we increment two rows
            // for every MCU processed.
            mcu_height *= self.v_max;
            mcu_height /= self.h_max;
            self.coeff = 2;
        }

        mcu_width *= 64;

        if self.input_colorspace.num_components() > self.components.len() {
            let msg = format!(
                "Expected {} components but found {}",
                self.input_colorspace.num_components(),
                self.components.len()
            );
            return Err(DecodeErrors::Format(msg));
        }
        for i in 0..self.input_colorspace.num_components() {
            let comp = &self.components[i];
            let len = mcu_width * comp.vertical_sample * comp.horizontal_sample * mcu_height;

            block[i] = vec![0; len];
        }

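        // `spec_start..=spec_end` give the band of zig-zag coefficients carried by the
        // current scan (a DC scan has `spec_start == 0`), while `succ_high`/`succ_low`
        // are the successive-approximation bit positions that tell a first scan apart
        // from a refinement scan (`succ_high == 0` means first scan).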
        let mut stream = BitStream::new_progressive(
            self.succ_high,
            self.succ_low,
            self.spec_start,
            self.spec_end
        );

        // there are multiple scans in the stream, this should resolve the first scan
        self.parse_entropy_coded_data(&mut stream, &mut block)?;

        // extract marker
        let mut marker = stream
            .marker
            .take()
            .ok_or(DecodeErrors::FormatStatic("Marker missing where expected"))?;

        // if marker is EOI, we are done, otherwise continue scanning.
        //
        // In case we have a premature image, we print a warning or return
        // an error, depending on the strictness of the decoder, so there
        // is that logic to handle too
        'eoi: while marker != Marker::EOI {
            match marker {
                Marker::DHT => {
                    parse_huffman(self)?;
                }
                Marker::SOS => {
                    parse_sos(self)?;

                    stream.update_progressive_params(
                        self.succ_high,
                        self.succ_low,
                        self.spec_start,
                        self.spec_end
                    );

                    // after every SOS marker, parse data for that scan.
                    self.parse_entropy_coded_data(&mut stream, &mut block)?;
                    // extract marker, it might either indicate end of image or that we
                    // continue scanning (hence the continue statement to determine).
                    match get_marker(&mut self.stream, &mut stream) {
                        Ok(marker_n) => {
                            marker = marker_n;
                            seen_scans += 1;
                            if seen_scans > self.options.jpeg_get_max_scans() {
                                return Err(DecodeErrors::Format(format!(
                                    "Too many scans, exceeded limit of {}",
                                    self.options.jpeg_get_max_scans()
                                )));
                            }

                            stream.reset();
                            continue 'eoi;
                        }
                        Err(msg) => {
                            if self.options.get_strict_mode() {
                                return Err(msg);
                            }
                            error!("{:?}", msg);
                            break 'eoi;
                        }
                    }
                }
                _ => {
                    break 'eoi;
                }
            }

            match get_marker(&mut self.stream, &mut stream) {
                Ok(marker_n) => {
                    marker = marker_n;
                }
                Err(e) => {
                    if self.options.get_strict_mode() {
                        return Err(e);
                    }
                    error!("{}", e);
                }
            }
        }

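        // All scans have been consumed (or we bailed out leniently on a malformed
        // stream); reconstruct the pixels from the accumulated coefficient buffers.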
        self.finish_progressive_decoding(&block, mcu_width, pixels)
    }

    #[allow(clippy::too_many_lines, clippy::cast_sign_loss)]
    fn parse_entropy_coded_data(
        &mut self, stream: &mut BitStream, buffer: &mut [Vec<i16>; MAX_COMPONENTS]
    ) -> Result<(), DecodeErrors> {
        stream.reset();
        self.components.iter_mut().for_each(|x| x.dc_pred = 0);

        if usize::from(self.num_scans) > self.input_colorspace.num_components() {
            return Err(Format(format!(
                "Number of scans {} cannot be greater than number of components, {}",
                self.num_scans,
                self.input_colorspace.num_components()
            )));
        }

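        // A scan in a progressive frame either covers a single component (always the
        // case for AC scans) or interleaves the DC coefficients of several components;
        // the two cases are handled separately below.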
        if self.num_scans == 1 {
            // Safety checks
            if self.spec_end != 0 && self.spec_start == 0 {
                return Err(DecodeErrors::FormatStatic(
                    "Can't merge DC and AC, corrupt jpeg"
                ));
            }
            // non-interleaved data, process one block at a time in trivial scanline order

            let k = self.z_order[0];

            if k >= self.components.len() {
                return Err(DecodeErrors::Format(format!(
                    "Cannot find component {k}, corrupt image"
                )));
            }

            let (mcu_width, mcu_height);

            if self.components[k].component_id == ComponentID::Y
                && (self.components[k].vertical_sample != 1
                    || self.components[k].horizontal_sample != 1)
                || !self.is_interleaved
            {
                // For the Y channel, or for non-interleaved scans,
                // the number of MCUs is the image dimension divided by 8
                mcu_width = ((self.info.width + 7) / 8) as usize;
                mcu_height = ((self.info.height + 7) / 8) as usize;
            } else {
                // For other channels in an interleaved MCU, the number of MCUs
                // is determined by some weird maths done in headers.rs->parse_sos()
                mcu_width = self.mcu_x;
                mcu_height = self.mcu_y;
            }

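            // Blocks are visited in raster order. While an EOB (end-of-band) run is
            // active, that many consecutive blocks are known to have no further
            // coefficients in this band, so they are skipped without reading bits.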
            for i in 0..mcu_height {
                for j in 0..mcu_width {
                    if self.spec_start != 0 && self.succ_high == 0 && stream.eob_run > 0 {
                        // handle EOB runs here.
                        stream.eob_run -= 1;
                        continue;
                    }
                    let start = 64 * (j + i * (self.components[k].width_stride / 8));

                    let data: &mut [i16; 64] = buffer
                        .get_mut(k)
                        .unwrap()
                        .get_mut(start..start + 64)
                        .unwrap()
                        .try_into()
                        .unwrap();

                    if self.spec_start == 0 {
                        let pos = self.components[k].dc_huff_table & (MAX_COMPONENTS - 1);
                        let dc_table = self
                            .dc_huffman_tables
                            .get(pos)
                            .ok_or(DecodeErrors::FormatStatic(
                                "No huffman table for DC component"
                            ))?
                            .as_ref()
                            .ok_or(DecodeErrors::FormatStatic(
                                "Huffman table for DC component not initialized"
                            ))?;

                        let dc_pred = &mut self.components[k].dc_pred;

                        if self.succ_high == 0 {
                            // first scan for this mcu
                            stream.decode_prog_dc_first(
                                &mut self.stream,
                                dc_table,
                                &mut data[0],
                                dc_pred
                            )?;
                        } else {
                            // refining scans for this MCU
                            stream.decode_prog_dc_refine(&mut self.stream, &mut data[0])?;
                        }
                    } else {
                        let pos = self.components[k].ac_huff_table;
                        let ac_table = self
                            .ac_huffman_tables
                            .get(pos)
                            .ok_or_else(|| {
                                DecodeErrors::Format(format!(
                                    "No huffman table for component:{pos}"
                                ))
                            })?
                            .as_ref()
                            .ok_or_else(|| {
                                DecodeErrors::Format(format!(
                                    "Huffman table at index {pos} not initialized"
                                ))
                            })?;

                        if self.succ_high == 0 {
                            debug_assert!(stream.eob_run == 0, "EOB run is not zero");

                            stream.decode_mcu_ac_first(&mut self.stream, ac_table, data)?;
                        } else {
                            // refinement scan
                            stream.decode_mcu_ac_refine(&mut self.stream, ac_table, data)?;
                        }
                    }
                    // count down to the next restart marker.
                    self.todo -= 1;

                    if self.todo == 0 {
                        self.handle_rst(stream)?;
                    }
                }
            }
        } else {
            if self.spec_end != 0 {
                return Err(DecodeErrors::HuffmanDecode(
                    "Can't merge DC and AC, corrupt jpeg".to_string()
                ));
            }
            // process scan n elements in order

            // Do the error checking with allocs here.
            // Make the one in the inner loop free of allocations.
            for k in 0..self.num_scans {
                let n = self.z_order[k as usize];

                if n >= self.components.len() {
                    return Err(DecodeErrors::Format(format!(
                        "Cannot find component {n}, corrupt image"
                    )));
                }

                let component = &mut self.components[n];
                let _ = self
                    .dc_huffman_tables
                    .get(component.dc_huff_table)
                    .ok_or_else(|| {
                        DecodeErrors::Format(format!(
                            "No huffman table for component:{}",
                            component.dc_huff_table
                        ))
                    })?
                    .as_ref()
                    .ok_or_else(|| {
                        DecodeErrors::Format(format!(
                            "Huffman table at index {} not initialized",
                            component.dc_huff_table
                        ))
                    })?;
            }
            // Interleaved scan

            // Components shall not be interleaved in progressive mode, except for
            // the DC coefficients in the first scan for each component of a progressive frame.
            for i in 0..self.mcu_y {
                for j in 0..self.mcu_x {
                    // process scan n elements in order
                    for k in 0..self.num_scans {
                        let n = self.z_order[k as usize];
                        let component = &mut self.components[n];
                        let huff_table = self
                            .dc_huffman_tables
                            .get(component.dc_huff_table)
                            .ok_or(DecodeErrors::FormatStatic("No huffman table for component"))?
                            .as_ref()
                            .ok_or(DecodeErrors::FormatStatic(
                                "Huffman table not initialized"
                            ))?;

                        for v_samp in 0..component.vertical_sample {
                            for h_samp in 0..component.horizontal_sample {
                                let x2 = j * component.horizontal_sample + h_samp;
                                let y2 = i * component.vertical_sample + v_samp;
                                let position = 64 * (x2 + y2 * component.width_stride / 8);

                                let data = &mut buffer[n][position];

                                if self.succ_high == 0 {
                                    stream.decode_prog_dc_first(
                                        &mut self.stream,
                                        huff_table,
                                        data,
                                        &mut component.dc_pred
                                    )?;
                                } else {
                                    stream.decode_prog_dc_refine(&mut self.stream, data)?;
                                }
                            }
                        }
                    }
                    // We want wrapping subtraction here because it means
                    // we get a higher number in the case this underflows
                    self.todo = self.todo.wrapping_sub(1);
                    // after every MCU, count down to the next restart marker.
                    if self.todo == 0 {
                        self.handle_rst(stream)?;
                    }
                }
            }
        }
        return Ok(());
    }

    #[allow(clippy::too_many_lines)]
    #[allow(clippy::needless_range_loop, clippy::cast_sign_loss)]
    fn finish_progressive_decoding(
        &mut self, block: &[Vec<i16>; MAX_COMPONENTS], _mcu_width: usize, pixels: &mut [u8]
    ) -> Result<(), DecodeErrors> {
        // This function is complicated because we need to replicate
        // the function in mcu.rs
        //
        // The advantage is that we do very little allocation and a lot of
        // channel reuse.
        // The trick is to notice that we repeat the same procedure per MCU
        // width.
        //
        // So we can set it up so that we only allocate temporary storage large enough
        // to store a single MCU width, then reuse it per invocation.
        //
        // This is advantageous to us.
        //
        // Remember we need the whole MCU buffer, so we store 3 unprocessed
        // channels in memory, and then we allocate the whole output buffer in memory, both of
        // which are huge.
        //
        //

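        // Per MCU row we dequantize every 8x8 block with the component's quantization
        // table, run the IDCT into the component's scratch channel, and then hand the
        // row to `post_process` for upsampling and color conversion.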
        let mcu_height = if self.is_interleaved {
            self.mcu_y
        } else {
            // For non-interleaved images (1x1 subsampling), the number of MCU rows
            // is the height (+7 to account for padding) divided by 8.
            ((self.info.height + 7) / 8) as usize
        };

        // Size of our output image (width*height)
        let is_hv = usize::from(self.is_interleaved);
        let upsampler_scratch_size = is_hv * self.components[0].width_stride;
        let width = usize::from(self.info.width);
        let padded_width = calculate_padded_width(width, self.sub_sample_ratio);

        //let mut pixels = vec![0; capacity * out_colorspace_components];
        let mut upsampler_scratch_space = vec![0; upsampler_scratch_size];
        let mut tmp = [0_i32; DCT_BLOCK];
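        // `tmp` receives one dequantized 8x8 block at a time before the IDCT scatters
        // it into the component's scratch channel; `upsampler_scratch_space` is a
        // single row handed to `post_process` for upsampling of interleaved images.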

        for (pos, comp) in self.components.iter_mut().enumerate() {
            // Allocate only needed components.
            //
            // For special colorspaces, i.e. YCCK and CMYK, just allocate all of the
            // components.
            if min(
                self.options.jpeg_get_out_colorspace().num_components() - 1,
                pos
            ) == pos
                || self.input_colorspace == ColorSpace::YCCK
                || self.input_colorspace == ColorSpace::CMYK
            {
                // allocate enough space to hold a whole MCU width,
                // which means we should take into account sampling ratios.
                // `*8` is because each block spans 8 rows.
                let len = comp.width_stride * comp.vertical_sample * 8;

                comp.needed = true;
                comp.raw_coeff = vec![0; len];
            } else {
                comp.needed = false;
            }
        }

        let mut pixels_written = 0;

        // dequantize, idct and color convert.
        for i in 0..mcu_height {
            'component: for (position, component) in &mut self.components.iter_mut().enumerate() {
                if !component.needed {
                    continue 'component;
                }
                let qt_table = &component.quantization_table;

                // step is the number of pixels this iteration will be handling.
                // It is the length of the component block divided by the number of MCU rows;
                // since the component block contains the whole channel,
                // this evenly divides the pixels into MCU rows.
                //
                // For interleaved images, this gives us the exact pixels comprising a whole MCU
                // block
                let step = block[position].len() / mcu_height;
                // where we will be reading our pixels from.
                let start = i * step;

                let slice = &block[position][start..start + step];

                let temp_channel = &mut component.raw_coeff;

                // The next logical step is to iterate width wise.
                // To figure out how many blocks we iterate by, we use the effective
                // width in blocks, i.e. the width stride divided by 8.
                let mcu_x = component.width_stride / 8;

                // iterate per every vertical sample.
                for k in 0..component.vertical_sample {
                    for j in 0..mcu_x {
                        // after writing a single stride, we need to skip 8 rows.
                        // This does the row calculation
                        let width_stride = k * 8 * component.width_stride;
                        let start = j * 64 + width_stride;

                        // dequantize
                        for ((x, out), qt_val) in slice[start..start + 64]
                            .iter()
                            .zip(tmp.iter_mut())
                            .zip(qt_table.iter())
                        {
                            *out = i32::from(*x) * qt_val;
                        }
                        // determine where to write.
                        let sl = &mut temp_channel[component.idct_pos..];

                        component.idct_pos += 8;
                        // tmp now contains a dequantized block so idct it
                        (self.idct_func)(&mut tmp, sl, component.width_stride);
                    }
                    // after every row of blocks, skip 7 more strides, since the idct
                    // writes stride-wise 8 times.
                    //
                    // Remember each MCU is an 8x8 block, so each idct will write 8 strides into
                    // sl
                    //
                    // and component.idct_pos is one stride long
                    component.idct_pos += 7 * component.width_stride;
                }
                component.idct_pos = 0;
            }

            // process that width up until it's impossible
            self.post_process(
                pixels,
                i,
                mcu_height,
                width,
                padded_width,
                &mut pixels_written,
                &mut upsampler_scratch_space
            )?;
        }

        debug!("Finished decoding image");

        return Ok(());
    }
    pub(crate) fn reset_params(&mut self) {
        /*
        Apparently, grayscale images which can be down-sampled exist, which is weird in the sense
        that they have one component, Y, which is not usually down-sampled.

        This means some calculations will be wrong, so for such occurrences we explicitly
        reset params, warn and reset the image info to appear as if it were
        a non-sampled image to ensure decoding works.
        */
        self.h_max = 1;
        self.options = self.options.jpeg_set_out_colorspace(ColorSpace::Luma);
        self.v_max = 1;
        self.sub_sample_ratio = SampleRatios::None;
        self.is_interleaved = false;
        self.components[0].vertical_sample = 1;
        self.components[0].width_stride = (((self.info.width as usize) + 7) / 8) * 8;
        self.components[0].horizontal_sample = 1;
    }
}

/// Get a marker from the bit-stream.
///
/// This reads until it gets a marker or end of file is encountered
fn get_marker<T>(
    reader: &mut ZByteReader<T>, stream: &mut BitStream
) -> Result<Marker, DecodeErrors>
where
    T: ZReaderTrait
{
    if let Some(marker) = stream.marker {
        stream.marker = None;
        return Ok(marker);
    }

    // read until we get a marker

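    // A marker is a 0xFF byte followed by a non-zero, non-0xFF byte; a 0xFF 0x00 pair
    // is a stuffed byte inside entropy-coded data, so it is not treated as a marker.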
    while !reader.eof() {
        let marker = reader.get_u8_err()?;

        if marker == 255 {
            let mut r = reader.get_u8_err()?;
            // 0xFF 0xFF (some images may be like that)
            while r == 0xFF {
                r = reader.get_u8_err()?;
            }

            if r != 0 {
                return Marker::from_u8(r)
                    .ok_or_else(|| DecodeErrors::Format(format!("Unknown marker 0xFF{r:X}")));
            }
        }
    }
    return Err(DecodeErrors::ExhaustedData);
}