png/decoder/
transform.rs

1//! Transforming a decompressed, unfiltered row into the final output.
2
3mod palette;
4
5use crate::{BitDepth, ColorType, DecodingError, Info, Transformations};
6
7use super::stream::FormatErrorInner;
8
9/// Type of a function that can transform a decompressed, unfiltered row (the
10/// 1st argument) into the final pixels (the 2nd argument), optionally using
11/// image metadata (e.g. PLTE data can be accessed using the 3rd argument).
12///
13/// TODO: If some precomputed state is needed (e.g. to make `expand_paletted...`
14/// faster) then consider changing this into `Box<dyn Fn(...)>`.
15pub type TransformFn = Box<dyn Fn(&[u8], &mut [u8], &Info) + Send + Sync>;
16
17/// Returns a transformation function that should be applied to image rows based
18/// on 1) decoded image metadata (`info`) and 2) the transformations requested
19/// by the crate client (`transform`).
20pub fn create_transform_fn(
21    info: &Info,
22    transform: Transformations,
23) -> Result<TransformFn, DecodingError> {
24    let color_type = info.color_type;
25    let bit_depth = info.bit_depth as u8;
26    let trns = info.trns.is_some() || transform.contains(Transformations::ALPHA);
27    let expand =
28        transform.contains(Transformations::EXPAND) || transform.contains(Transformations::ALPHA);
29    let strip16 = bit_depth == 16 && transform.contains(Transformations::STRIP_16);
30    match color_type {
31        ColorType::Indexed if expand => {
32            if info.palette.is_none() {
33                Err(DecodingError::Format(
34                    FormatErrorInner::PaletteRequired.into(),
35                ))
36            } else if let BitDepth::Sixteen = info.bit_depth {
37                // This should have been caught earlier but let's check again. Can't hurt.
38                Err(DecodingError::Format(
39                    FormatErrorInner::InvalidColorBitDepth {
40                        color_type: ColorType::Indexed,
41                        bit_depth: BitDepth::Sixteen,
42                    }
43                    .into(),
44                ))
45            } else {
46                Ok(if trns {
47                    palette::create_expansion_into_rgba8(info)
48                } else {
49                    palette::create_expansion_into_rgb8(info)
50                })
51            }
52        }
53        ColorType::Grayscale | ColorType::GrayscaleAlpha if bit_depth < 8 && expand => {
54            Ok(Box::new(if trns {
55                expand_gray_u8_with_trns
56            } else {
57                expand_gray_u8
58            }))
59        }
60        ColorType::Grayscale | ColorType::Rgb if expand && trns => {
61            Ok(Box::new(if bit_depth == 8 {
62                expand_trns_line
63            } else if strip16 {
64                expand_trns_and_strip_line16
65            } else {
66                assert_eq!(bit_depth, 16);
67                expand_trns_line16
68            }))
69        }
70        ColorType::Grayscale | ColorType::GrayscaleAlpha | ColorType::Rgb | ColorType::Rgba
71            if strip16 =>
72        {
73            Ok(Box::new(transform_row_strip16))
74        }
75        _ => Ok(Box::new(copy_row)),
76    }
77}
78
79fn copy_row(row: &[u8], output_buffer: &mut [u8], _: &Info) {
80    output_buffer.copy_from_slice(row);
81}
82
83fn transform_row_strip16(row: &[u8], output_buffer: &mut [u8], _: &Info) {
84    for i in 0..row.len() / 2 {
85        output_buffer[i] = row[2 * i];
86    }
87}
88
/// Splits each byte of `input` into `8 / bit_depth` samples and calls `func`
/// once per sample, passing the sample value and the next `channels`-byte
/// chunk of `output` to fill.
///
/// Samples are taken most-significant bits first: with `bit_depth == 2` the
/// byte `0b11_10_01_00` produces `3, 2, 1, 0`. With `bit_depth == 8` each byte
/// is passed through as one sample. Processing stops after the last complete
/// `channels`-sized chunk of `output`; a trailing partial chunk is never
/// passed to `func`.
///
/// # Panics
///
/// Panics if `bit_depth` is not 1, 2, 4 or 8, if
/// `(8 / bit_depth) * channels * input.len() < output.len()` (not enough input),
/// or if `channels` is 0.
#[inline(always)]
fn unpack_bits<F>(input: &[u8], output: &mut [u8], channels: usize, bit_depth: u8, func: F)
where
    F: Fn(u8, &mut [u8]),
{
    assert!(matches!(bit_depth, 1 | 2 | 4 | 8));
    // Each input byte yields `8 / bit_depth` samples and each sample fills
    // `channels` output bytes; this check guarantees `input` never runs dry below.
    let samples_per_byte = 8 / bit_depth as usize;
    assert!((samples_per_byte * channels).saturating_mul(input.len()) >= output.len());

    let mut chunks = output.chunks_exact_mut(channels);
    let mut bytes = input.iter();

    if bit_depth == 8 {
        // One whole byte per sample: no shifting or masking needed.
        for (&byte, chunk) in bytes.zip(&mut chunks) {
            func(byte, chunk);
        }
        return;
    }

    let mask = ((1u16 << bit_depth) - 1) as u8;
    let first_shift = 8 - bit_depth as i32;
    // `shift` walks first_shift, first_shift - bit_depth, ..., 0 within each
    // byte (e.g. 6, 4, 2, 0 for 2-bit samples); a negative value means the
    // current byte is used up. Computing shifts like this optimizes better
    // than `(0..8).step_by(bit_depth.into()).rev()` (2023-08, Rust 1.71).
    let mut shift = -1;
    let mut byte = 0;
    for chunk in chunks {
        if shift < 0 {
            byte = *bytes.next().expect("input for unpack bits is not empty");
            shift = first_shift;
        }
        func((byte >> shift) & mask, chunk);
        shift -= bit_depth as i32;
    }
}
136
137fn expand_trns_line(input: &[u8], output: &mut [u8], info: &Info) {
138    let channels = info.color_type.samples();
139    let trns = info.trns.as_deref();
140    for (input, output) in input
141        .chunks_exact(channels)
142        .zip(output.chunks_exact_mut(channels + 1))
143    {
144        output[..channels].copy_from_slice(input);
145        output[channels] = if Some(input) == trns { 0 } else { 0xFF };
146    }
147}
148
149fn expand_trns_line16(input: &[u8], output: &mut [u8], info: &Info) {
150    let channels = info.color_type.samples();
151    let trns = info.trns.as_deref();
152    for (input, output) in input
153        .chunks_exact(channels * 2)
154        .zip(output.chunks_exact_mut(channels * 2 + 2))
155    {
156        output[..channels * 2].copy_from_slice(input);
157        if Some(input) == trns {
158            output[channels * 2] = 0;
159            output[channels * 2 + 1] = 0
160        } else {
161            output[channels * 2] = 0xFF;
162            output[channels * 2 + 1] = 0xFF
163        };
164    }
165}
166
167fn expand_trns_and_strip_line16(input: &[u8], output: &mut [u8], info: &Info) {
168    let channels = info.color_type.samples();
169    let trns = info.trns.as_deref();
170    for (input, output) in input
171        .chunks_exact(channels * 2)
172        .zip(output.chunks_exact_mut(channels + 1))
173    {
174        for i in 0..channels {
175            output[i] = input[i * 2];
176        }
177        output[channels] = if Some(input) == trns { 0 } else { 0xFF };
178    }
179}
180
181fn expand_gray_u8(row: &[u8], buffer: &mut [u8], info: &Info) {
182    let scaling_factor = (255) / ((1u16 << info.bit_depth as u8) - 1) as u8;
183    unpack_bits(row, buffer, 1, info.bit_depth as u8, |val, chunk| {
184        chunk[0] = val * scaling_factor
185    });
186}
187
188fn expand_gray_u8_with_trns(row: &[u8], buffer: &mut [u8], info: &Info) {
189    let scaling_factor = (255) / ((1u16 << info.bit_depth as u8) - 1) as u8;
190    let trns = info.trns.as_deref();
191    unpack_bits(row, buffer, 2, info.bit_depth as u8, |pixel, chunk| {
192        chunk[1] = if let Some(trns) = trns {
193            if pixel == trns[0] {
194                0
195            } else {
196                0xFF
197            }
198        } else {
199            0xFF
200        };
201        chunk[0] = pixel * scaling_factor
202    });
203}