// File: zune_jpeg/mcu_prog.rs
/*
 * Copyright (c) 2023.
 *
 * This software is free software;
 *
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
 */

//! Routines for progressive decoding
/*
This file is needlessly complicated,

It is that way to ensure we don't burn memory anyhow

Memory is a scarce resource in some environments, I would like this to be viable
in such environments

Half of the complexity comes from the jpeg spec, because progressive decoding
is one hell of a ride.

*/
22use alloc::string::ToString;
23use alloc::vec::Vec;
24use alloc::{format, vec};
25use core::cmp::min;
26
27use zune_core::bytestream::{ZByteReader, ZReaderTrait};
28use zune_core::colorspace::ColorSpace;
29use zune_core::log::{debug, error, warn};
30
31use crate::bitstream::BitStream;
32use crate::components::{ComponentID, SampleRatios};
33use crate::decoder::{JpegDecoder, MAX_COMPONENTS};
34use crate::errors::DecodeErrors;
35use crate::errors::DecodeErrors::Format;
36use crate::headers::{parse_huffman, parse_sos};
37use crate::marker::Marker;
38use crate::mcu::DCT_BLOCK;
39use crate::misc::{calculate_padded_width, setup_component_params};
40
41impl<T: ZReaderTrait> JpegDecoder<T> {
42    /// Decode a progressive image
43    ///
44    /// This routine decodes a progressive image, stopping if it finds any error.
45    #[allow(
46        clippy::needless_range_loop,
47        clippy::cast_sign_loss,
48        clippy::redundant_else,
49        clippy::too_many_lines
50    )]
51    #[inline(never)]
52    pub(crate) fn decode_mcu_ycbcr_progressive(
53        &mut self, pixels: &mut [u8]
54    ) -> Result<(), DecodeErrors> {
55        setup_component_params(self)?;
56
57        let mut mcu_height;
58
59        // memory location for decoded pixels for components
60        let mut block: [Vec<i16>; MAX_COMPONENTS] = [vec![], vec![], vec![], vec![]];
61        let mut mcu_width;
62
63        let mut seen_scans = 1;
64
65        if self.input_colorspace == ColorSpace::Luma && self.is_interleaved {
66            warn!("Grayscale image with down-sampled component, resetting component details");
67            self.reset_params();
68        }
69
70        if self.is_interleaved {
71            // this helps us catch component errors.
72            self.set_upsampling()?;
73        }
74        if self.is_interleaved {
75            mcu_width = self.mcu_x;
76            mcu_height = self.mcu_y;
77        } else {
78            mcu_width = (self.info.width as usize + 7) / 8;
79            mcu_height = (self.info.height as usize + 7) / 8;
80        }
81        if self.is_interleaved
82            && self.input_colorspace.num_components() > 1
83            && self.options.jpeg_get_out_colorspace().num_components() == 1
84            && (self.sub_sample_ratio == SampleRatios::V
85                || self.sub_sample_ratio == SampleRatios::HV)
86        {
87            // For a specific set of images, e.g interleaved,
88            // when converting from YcbCr to grayscale, we need to
89            // take into account mcu height since the MCU decoding needs to take
90            // it into account for padding purposes and the post processor
91            // parses two rows per mcu width.
92            //
93            // set coeff to be 2 to ensure that we increment two rows
94            // for every mcu processed also
95            mcu_height *= self.v_max;
96            mcu_height /= self.h_max;
97            self.coeff = 2;
98        }
99
100        mcu_width *= 64;
101
102        if self.input_colorspace.num_components() > self.components.len() {
103            let msg = format!(
104                " Expected {} number of components but found {}",
105                self.input_colorspace.num_components(),
106                self.components.len()
107            );
108            return Err(DecodeErrors::Format(msg));
109        }
110        for i in 0..self.input_colorspace.num_components() {
111            let comp = &self.components[i];
112            let len = mcu_width * comp.vertical_sample * comp.horizontal_sample * mcu_height;
113
114            block[i] = vec![0; len];
115        }
116
117        let mut stream = BitStream::new_progressive(
118            self.succ_high,
119            self.succ_low,
120            self.spec_start,
121            self.spec_end
122        );
123
124        // there are multiple scans in the stream, this should resolve the first scan
125        self.parse_entropy_coded_data(&mut stream, &mut block)?;
126
127        // extract marker
128        let mut marker = stream
129            .marker
130            .take()
131            .ok_or(DecodeErrors::FormatStatic("Marker missing where expected"))?;
132
133        // if marker is EOI, we are done, otherwise continue scanning.
134        //
135        // In case we have a premature image, we print a warning or return
136        // an error, depending on the strictness of the decoder, so there
137        // is that logic to handle too
138        'eoi: while marker != Marker::EOI {
139            match marker {
140                Marker::DHT => {
141                    parse_huffman(self)?;
142                }
143                Marker::SOS => {
144                    parse_sos(self)?;
145
146                    stream.update_progressive_params(
147                        self.succ_high,
148                        self.succ_low,
149                        self.spec_start,
150                        self.spec_end
151                    );
152
153                    // after every SOS, marker, parse data for that scan.
154                    self.parse_entropy_coded_data(&mut stream, &mut block)?;
155                    // extract marker, might either indicate end of image or we continue
156                    // scanning(hence the continue statement to determine).
157                    match get_marker(&mut self.stream, &mut stream) {
158                        Ok(marker_n) => {
159                            marker = marker_n;
160                            seen_scans += 1;
161                            if seen_scans > self.options.jpeg_get_max_scans() {
162                                return Err(DecodeErrors::Format(format!(
163                                    "Too many scans, exceeded limit of {}",
164                                    self.options.jpeg_get_max_scans()
165                                )));
166                            }
167
168                            stream.reset();
169                            continue 'eoi;
170                        }
171                        Err(msg) => {
172                            if self.options.get_strict_mode() {
173                                return Err(msg);
174                            }
175                            error!("{:?}", msg);
176                            break 'eoi;
177                        }
178                    }
179                }
180                _ => {
181                    break 'eoi;
182                }
183            }
184
185            match get_marker(&mut self.stream, &mut stream) {
186                Ok(marker_n) => {
187                    marker = marker_n;
188                }
189                Err(e) => {
190                    if self.options.get_strict_mode() {
191                        return Err(e);
192                    }
193                    error!("{}", e);
194                }
195            }
196        }
197
198        self.finish_progressive_decoding(&block, mcu_width, pixels)
199    }
200
201    #[allow(clippy::too_many_lines, clippy::cast_sign_loss)]
202    fn parse_entropy_coded_data(
203        &mut self, stream: &mut BitStream, buffer: &mut [Vec<i16>; MAX_COMPONENTS]
204    ) -> Result<(), DecodeErrors> {
205        stream.reset();
206        self.components.iter_mut().for_each(|x| x.dc_pred = 0);
207
208        if usize::from(self.num_scans) > self.input_colorspace.num_components() {
209            return Err(Format(format!(
210                "Number of scans {} cannot be greater than number of components, {}",
211                self.num_scans,
212                self.input_colorspace.num_components()
213            )));
214        }
215
216        if self.num_scans == 1 {
217            // Safety checks
218            if self.spec_end != 0 && self.spec_start == 0 {
219                return Err(DecodeErrors::FormatStatic(
220                    "Can't merge DC and AC corrupt jpeg"
221                ));
222            }
223            // non interleaved data, process one block at a time in trivial scanline order
224
225            let k = self.z_order[0];
226
227            if k >= self.components.len() {
228                return Err(DecodeErrors::Format(format!(
229                    "Cannot find component {k}, corrupt image"
230                )));
231            }
232
233            let (mcu_width, mcu_height);
234
235            if self.components[k].component_id == ComponentID::Y
236                && (self.components[k].vertical_sample != 1
237                    || self.components[k].horizontal_sample != 1)
238                || !self.is_interleaved
239            {
240                // For Y channel  or non interleaved scans ,
241                // mcu's is the image dimensions divided by 8
242                mcu_width = ((self.info.width + 7) / 8) as usize;
243                mcu_height = ((self.info.height + 7) / 8) as usize;
244            } else {
245                // For other channels, in an interleaved mcu, number of MCU's
246                // are determined by some weird maths done in headers.rs->parse_sos()
247                mcu_width = self.mcu_x;
248                mcu_height = self.mcu_y;
249            }
250
251            for i in 0..mcu_height {
252                for j in 0..mcu_width {
253                    if self.spec_start != 0 && self.succ_high == 0 && stream.eob_run > 0 {
254                        // handle EOB runs here.
255                        stream.eob_run -= 1;
256                        continue;
257                    }
258                    let start = 64 * (j + i * (self.components[k].width_stride / 8));
259
260                    let data: &mut [i16; 64] = buffer
261                        .get_mut(k)
262                        .unwrap()
263                        .get_mut(start..start + 64)
264                        .unwrap()
265                        .try_into()
266                        .unwrap();
267
268                    if self.spec_start == 0 {
269                        let pos = self.components[k].dc_huff_table & (MAX_COMPONENTS - 1);
270                        let dc_table = self
271                            .dc_huffman_tables
272                            .get(pos)
273                            .ok_or(DecodeErrors::FormatStatic(
274                                "No huffman table for DC component"
275                            ))?
276                            .as_ref()
277                            .ok_or(DecodeErrors::FormatStatic(
278                                "Huffman table at index  {} not initialized"
279                            ))?;
280
281                        let dc_pred = &mut self.components[k].dc_pred;
282
283                        if self.succ_high == 0 {
284                            // first scan for this mcu
285                            stream.decode_prog_dc_first(
286                                &mut self.stream,
287                                dc_table,
288                                &mut data[0],
289                                dc_pred
290                            )?;
291                        } else {
292                            // refining scans for this MCU
293                            stream.decode_prog_dc_refine(&mut self.stream, &mut data[0])?;
294                        }
295                    } else {
296                        let pos = self.components[k].ac_huff_table;
297                        let ac_table = self
298                            .ac_huffman_tables
299                            .get(pos)
300                            .ok_or_else(|| {
301                                DecodeErrors::Format(format!(
302                                    "No huffman table for component:{pos}"
303                                ))
304                            })?
305                            .as_ref()
306                            .ok_or_else(|| {
307                                DecodeErrors::Format(format!(
308                                    "Huffman table at index  {pos} not initialized"
309                                ))
310                            })?;
311
312                        if self.succ_high == 0 {
313                            debug_assert!(stream.eob_run == 0, "EOB run is not zero");
314
315                            stream.decode_mcu_ac_first(&mut self.stream, ac_table, data)?;
316                        } else {
317                            // refinement scan
318                            stream.decode_mcu_ac_refine(&mut self.stream, ac_table, data)?;
319                        }
320                    }
321                    // + EOB and investigate effect.
322                    self.todo -= 1;
323
324                    if self.todo == 0 {
325                        self.handle_rst(stream)?;
326                    }
327                }
328            }
329        } else {
330            if self.spec_end != 0 {
331                return Err(DecodeErrors::HuffmanDecode(
332                    "Can't merge dc and AC corrupt jpeg".to_string()
333                ));
334            }
335            // process scan n elements in order
336
337            // Do the error checking with allocs here.
338            // Make the one in the inner loop free of allocations.
339            for k in 0..self.num_scans {
340                let n = self.z_order[k as usize];
341
342                if n >= self.components.len() {
343                    return Err(DecodeErrors::Format(format!(
344                        "Cannot find component {n}, corrupt image"
345                    )));
346                }
347
348                let component = &mut self.components[n];
349                let _ = self
350                    .dc_huffman_tables
351                    .get(component.dc_huff_table)
352                    .ok_or_else(|| {
353                        DecodeErrors::Format(format!(
354                            "No huffman table for component:{}",
355                            component.dc_huff_table
356                        ))
357                    })?
358                    .as_ref()
359                    .ok_or_else(|| {
360                        DecodeErrors::Format(format!(
361                            "Huffman table at index  {} not initialized",
362                            component.dc_huff_table
363                        ))
364                    })?;
365            }
366            // Interleaved scan
367
368            // Components shall not be interleaved in progressive mode, except for
369            // the DC coefficients in the first scan for each component of a progressive frame.
370            for i in 0..self.mcu_y {
371                for j in 0..self.mcu_x {
372                    // process scan n elements in order
373                    for k in 0..self.num_scans {
374                        let n = self.z_order[k as usize];
375                        let component = &mut self.components[n];
376                        let huff_table = self
377                            .dc_huffman_tables
378                            .get(component.dc_huff_table)
379                            .ok_or(DecodeErrors::FormatStatic("No huffman table for component"))?
380                            .as_ref()
381                            .ok_or(DecodeErrors::FormatStatic(
382                                "Huffman table at index not initialized"
383                            ))?;
384
385                        for v_samp in 0..component.vertical_sample {
386                            for h_samp in 0..component.horizontal_sample {
387                                let x2 = j * component.horizontal_sample + h_samp;
388                                let y2 = i * component.vertical_sample + v_samp;
389                                let position = 64 * (x2 + y2 * component.width_stride / 8);
390
391                                let data = &mut buffer[n][position];
392
393                                if self.succ_high == 0 {
394                                    stream.decode_prog_dc_first(
395                                        &mut self.stream,
396                                        huff_table,
397                                        data,
398                                        &mut component.dc_pred
399                                    )?;
400                                } else {
401                                    stream.decode_prog_dc_refine(&mut self.stream, data)?;
402                                }
403                            }
404                        }
405                    }
406                    // We want wrapping subtraction here because it means
407                    // we get a higher number in the case this underflows
408                    self.todo = self.todo.wrapping_sub(1);
409                    // after every scan that's a mcu, count down restart markers.
410                    if self.todo == 0 {
411                        self.handle_rst(stream)?;
412                    }
413                }
414            }
415        }
416        return Ok(());
417    }
418
419    #[allow(clippy::too_many_lines)]
420    #[allow(clippy::needless_range_loop, clippy::cast_sign_loss)]
421    fn finish_progressive_decoding(
422        &mut self, block: &[Vec<i16>; MAX_COMPONENTS], _mcu_width: usize, pixels: &mut [u8]
423    ) -> Result<(), DecodeErrors> {
424        // This function is complicated because we need to replicate
425        // the function in mcu.rs
426        //
427        // The advantage is that we do very little allocation and very lot
428        // channel reusing.
429        // The trick is to notice that we repeat the same procedure per MCU
430        // width.
431        //
432        // So we can set it up that we only allocate temporary storage large enough
433        // to store a single mcu width, then reuse it per invocation.
434        //
435        // This is advantageous to us.
436        //
437        // Remember we need to have the whole MCU buffer so we store 3 unprocessed
438        // channels in memory, and then we allocate the whole output buffer in memory, both of
439        // which are huge.
440        //
441        //
442
443        let mcu_height = if self.is_interleaved {
444            self.mcu_y
445        } else {
446            // For non-interleaved images( (1*1) subsampling)
447            // number of MCU's are the widths (+7 to account for paddings) divided by 8.
448            ((self.info.height + 7) / 8) as usize
449        };
450
451        // Size of our output image(width*height)
452        let is_hv = usize::from(self.is_interleaved);
453        let upsampler_scratch_size = is_hv * self.components[0].width_stride;
454        let width = usize::from(self.info.width);
455        let padded_width = calculate_padded_width(width, self.sub_sample_ratio);
456
457        //let mut pixels = vec![0; capacity * out_colorspace_components];
458        let mut upsampler_scratch_space = vec![0; upsampler_scratch_size];
459        let mut tmp = [0_i32; DCT_BLOCK];
460
461        for (pos, comp) in self.components.iter_mut().enumerate() {
462            // Allocate only needed components.
463            //
464            // For special colorspaces i.e YCCK and CMYK, just allocate all of the needed
465            // components.
466            if min(
467                self.options.jpeg_get_out_colorspace().num_components() - 1,
468                pos
469            ) == pos
470                || self.input_colorspace == ColorSpace::YCCK
471                || self.input_colorspace == ColorSpace::CMYK
472            {
473                // allocate enough space to hold a whole MCU width
474                // this means we should take into account sampling ratios
475                // `*8` is because each MCU spans 8 widths.
476                let len = comp.width_stride * comp.vertical_sample * 8;
477
478                comp.needed = true;
479                comp.raw_coeff = vec![0; len];
480            } else {
481                comp.needed = false;
482            }
483        }
484
485        let mut pixels_written = 0;
486
487        // dequantize, idct and color convert.
488        for i in 0..mcu_height {
489            'component: for (position, component) in &mut self.components.iter_mut().enumerate() {
490                if !component.needed {
491                    continue 'component;
492                }
493                let qt_table = &component.quantization_table;
494
495                // step is the number of pixels this iteration wil be handling
496                // Given by the number of mcu's height and the length of the component block
497                // Since the component block contains the whole channel as raw pixels
498                // we this evenly divides the pixels into MCU blocks
499                //
500                // For interleaved images, this gives us the exact pixels comprising a whole MCU
501                // block
502                let step = block[position].len() / mcu_height;
503                // where we will be reading our pixels from.
504                let start = i * step;
505
506                let slice = &block[position][start..start + step];
507
508                let temp_channel = &mut component.raw_coeff;
509
510                // The next logical step is to iterate width wise.
511                // To figure out how many pixels we iterate by we use effective pixels
512                // Given to us by component.x
513                // iterate per effective pixels.
514                let mcu_x = component.width_stride / 8;
515
516                // iterate per every vertical sample.
517                for k in 0..component.vertical_sample {
518                    for j in 0..mcu_x {
519                        // after writing a single stride, we need to skip 8 rows.
520                        // This does the row calculation
521                        let width_stride = k * 8 * component.width_stride;
522                        let start = j * 64 + width_stride;
523
524                        // dequantize
525                        for ((x, out), qt_val) in slice[start..start + 64]
526                            .iter()
527                            .zip(tmp.iter_mut())
528                            .zip(qt_table.iter())
529                        {
530                            *out = i32::from(*x) * qt_val;
531                        }
532                        // determine where to write.
533                        let sl = &mut temp_channel[component.idct_pos..];
534
535                        component.idct_pos += 8;
536                        // tmp now contains a dequantized block so idct it
537                        (self.idct_func)(&mut tmp, sl, component.width_stride);
538                    }
539                    // after every write of 8, skip 7 since idct write stride wise 8 times.
540                    //
541                    // Remember each MCU is 8x8 block, so each idct will write 8 strides into
542                    // sl
543                    //
544                    // and component.idct_pos is one stride long
545                    component.idct_pos += 7 * component.width_stride;
546                }
547                component.idct_pos = 0;
548            }
549
550            // process that width up until it's impossible
551            self.post_process(
552                pixels,
553                i,
554                mcu_height,
555                width,
556                padded_width,
557                &mut pixels_written,
558                &mut upsampler_scratch_space
559            )?;
560        }
561
562        debug!("Finished decoding image");
563
564        return Ok(());
565    }
566    pub(crate) fn reset_params(&mut self) {
567        /*
568        Apparently, grayscale images which can be down sampled exists, which is weird in the sense
569        that it has one component Y, which is not usually down sampled.
570
571        This means some calculations will be wrong, so for that we explicitly reset params
572        for such occurrences, warn and reset the image info to appear as if it were
573        a non-sampled image to ensure decoding works
574        */
575        self.h_max = 1;
576        self.options = self.options.jpeg_set_out_colorspace(ColorSpace::Luma);
577        self.v_max = 1;
578        self.sub_sample_ratio = SampleRatios::None;
579        self.is_interleaved = false;
580        self.components[0].vertical_sample = 1;
581        self.components[0].width_stride = (((self.info.width as usize) + 7) / 8) * 8;
582        self.components[0].horizontal_sample = 1;
583    }
584}
585
586///Get a marker from the bit-stream.
587///
588/// This reads until it gets a marker or end of file is encountered
589fn get_marker<T>(
590    reader: &mut ZByteReader<T>, stream: &mut BitStream
591) -> Result<Marker, DecodeErrors>
592where
593    T: ZReaderTrait
594{
595    if let Some(marker) = stream.marker {
596        stream.marker = None;
597        return Ok(marker);
598    }
599
600    // read until we get a marker
601
602    while !reader.eof() {
603        let marker = reader.get_u8_err()?;
604
605        if marker == 255 {
606            let mut r = reader.get_u8_err()?;
607            // 0xFF 0XFF(some images may be like that)
608            while r == 0xFF {
609                r = reader.get_u8_err()?;
610            }
611
612            if r != 0 {
613                return Marker::from_u8(r)
614                    .ok_or_else(|| DecodeErrors::Format(format!("Unknown marker 0xFF{r:X}")));
615            }
616        }
617    }
618    return Err(DecodeErrors::ExhaustedData);
619}