// png/filter.rs

use core::convert::TryInto;

use crate::{common::BytesPerPixel, Compression};

/// The byte level filter applied to scanlines to prepare them for compression.
///
/// Compression in general benefits from repetitive data. The filter is a content-aware method of
/// compressing the range of occurring byte values to help the compression algorithm. Note that
/// this does not operate on pixels but on raw bytes of a scanline.
///
/// Details on how each filter works can be found in the [PNG Book](http://www.libpng.org/pub/png/book/chapter09.html).
///
/// The default filter is `Adaptive`, which uses heuristics to select the best filter for every row.
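///
/// # Example
///
/// A minimal sketch of forcing a specific filter when encoding (this assumes the
/// crate's public `Encoder::set_filter` API; adjust to the crate version you use):
///
/// ```ignore
/// let file = std::fs::File::create("out.png").unwrap();
/// let mut encoder = png::Encoder::new(file, 16, 16);
/// // Force Paeth for every row instead of the adaptive per-row heuristic.
/// encoder.set_filter(png::Filter::Paeth);
/// ```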
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Filter {
    NoFilter,
    Sub,
    Up,
    Avg,
    Paeth,
    Adaptive,
}

impl Default for Filter {
    fn default() -> Self {
        Filter::Adaptive
    }
}

impl From<RowFilter> for Filter {
    fn from(value: RowFilter) -> Self {
        match value {
            RowFilter::NoFilter => Filter::NoFilter,
            RowFilter::Sub => Filter::Sub,
            RowFilter::Up => Filter::Up,
            RowFilter::Avg => Filter::Avg,
            RowFilter::Paeth => Filter::Paeth,
        }
    }
}

impl Filter {
    pub(crate) fn from_simple(compression: Compression) -> Self {
        match compression {
            Compression::NoCompression => Filter::NoFilter, // with no DEFLATE, filtering would only waste time
            Compression::Fastest => Filter::Up, // pairs well with FdeflateUltraFast, producing much smaller files while being very fast
            Compression::Fast => Filter::Adaptive,
            Compression::Balanced => Filter::Adaptive,
            Compression::High => Filter::Adaptive,
        }
    }
}

/// Unlike the public [Filter], this does not include the "Adaptive" option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum RowFilter {
    NoFilter = 0,
    Sub = 1,
    Up = 2,
    Avg = 3,
    Paeth = 4,
}

impl Default for RowFilter {
    fn default() -> Self {
        RowFilter::Up
    }
}

impl RowFilter {
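    /// Parses the per-row filter type byte (0-4) as stored in a PNG stream.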
    pub fn from_u8(n: u8) -> Option<Self> {
        match n {
            0 => Some(Self::NoFilter),
            1 => Some(Self::Sub),
            2 => Some(Self::Up),
            3 => Some(Self::Avg),
            4 => Some(Self::Paeth),
            _ => None,
        }
    }

    pub fn from_method(strat: Filter) -> Option<Self> {
        match strat {
            Filter::NoFilter => Some(Self::NoFilter),
            Filter::Sub => Some(Self::Sub),
            Filter::Up => Some(Self::Up),
            Filter::Avg => Some(Self::Avg),
            Filter::Paeth => Some(Self::Paeth),
            Filter::Adaptive => None,
        }
    }
}

fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
    // On ARM this algorithm performs much better than the one below adapted from stb,
    // and this is the better-studied algorithm we've always used here,
    // so we default to it on all non-x86 platforms.
    let pa = (i16::from(b) - i16::from(c)).abs();
    let pb = (i16::from(a) - i16::from(c)).abs();
    let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();

    let mut out = a;
    let mut min = pa;

    if pb < min {
        min = pb;
        out = b;
    }
    if pc < min {
        out = c;
    }

    out
}

fn filter_paeth_stbi(a: u8, b: u8, c: u8) -> u8 {
    // Decoding optimizes better with this algorithm than with `filter_paeth`
    //
    // This formulation looks very different from the reference in the PNG spec, but is
    // actually equivalent and has favorable data dependencies and admits straightforward
    // generation of branch-free code, which helps performance significantly.
    //
    // Adapted from public domain PNG implementation:
    // https://github.com/nothings/stb/blob/5c205738c191bcb0abc65c4febfa9bd25ff35234/stb_image.h#L4657-L4668
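    //
    // Worked example (illustrative values, not from the stb source), with
    // (a, b, c) = (50, 80, 60): the spec computes p = 50 + 80 - 60 = 70,
    // pa = 20, pb = 10, pc = 10, and picks b = 80. Below, thresh = 3*60 - (50 + 80)
    // = 50, lo = 50, hi = 80, t0 = c = 60 (since hi > thresh), then t1 = hi = 80
    // (since thresh <= lo), matching the spec result.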
    let thresh = i16::from(c) * 3 - (i16::from(a) + i16::from(b));
    let lo = a.min(b);
    let hi = a.max(b);
    let t0 = if hi as i16 <= thresh { lo } else { c };
    let t1 = if thresh <= lo as i16 { hi } else { t0 };
    t1
}

fn filter_paeth_fpnge(a: u8, b: u8, c: u8) -> u8 {
    // This is an optimized version of the paeth filter from the PNG specification, proposed by
    // Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates
    // entirely on unsigned 8-bit quantities, making it more conducive to vectorization.
    //
    //     p = a + b - c
    //     pa = |p - a| = |a + b - c - a| = |b - c| = max(b, c) - min(b, c)
    //     pb = |p - b| = |a + b - c - b| = |a - c| = max(a, c) - min(a, c)
    //     pc = |p - c| = |a + b - c - c| = |(b - c) + (a - c)| = ...
    //
    // Further optimizing the calculation of `pc` is a bit trickier. However, notice that:
    //
    //        a > c && b > c
    //    ==> (a - c) > 0 && (b - c) > 0
    //    ==> pc > (a - c) && pc > (b - c)
    //    ==> pc > |a - c| && pc > |b - c|
    //    ==> pc > pb && pc > pa
    //
    // Meaning that if `c` is smaller than `a` and `b`, the value of `pc` is irrelevant. Similar
    // reasoning applies if `c` is larger than the other two inputs. Assuming that `c >= b` and
    // `c <= a` or vice versa:
    //
    //     pc = ||b - c| - |a - c|| = |pa - pb| = max(pa, pb) - min(pa, pb)
    //
    let pa = b.max(c) - c.min(b);
    let pb = a.max(c) - c.min(a);
    let pc = if (a < c) == (c < b) {
        pa.max(pb) - pa.min(pb)
    } else {
        255
    };

    if pa <= pb && pa <= pc {
        a
    } else if pb <= pc {
        b
    } else {
        c
    }
}

pub(crate) fn unfilter(
    mut filter: RowFilter,
    tbpp: BytesPerPixel,
    previous: &[u8],
    current: &mut [u8],
) {
    use self::RowFilter::*;

    // If the previous row is empty, then treat it as if it were filled with zeros.
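    // With an all-zero previous row, `Up` adds nothing and reduces to `NoFilter`,
    // while `Paeth` always predicts the left neighbor `a` (both `b` and `c` are
    // zero), so it reduces to `Sub`. `Avg` is handled by a dedicated match arm below.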
    if previous.is_empty() {
        if filter == Paeth {
            filter = Sub;
        } else if filter == Up {
            filter = NoFilter;
        }
    }

    // Auto-vectorization notes
    // ========================
    //
    // [2023/01 @okaneco] - Notes on optimizing decoding filters
    //
    // Links:
    // [PR]: https://github.com/image-rs/image-png/pull/382
    // [SWAR]: http://aggregate.org/SWAR/over.html
    // [AVG]: http://aggregate.org/MAGIC/#Average%20of%20Integers
    //
    // #382 heavily refactored and optimized the following filters making the
    // implementation nonobvious. These comments function as a summary of that
    // PR with an explanation of the choices made below.
    //
    // #382 originally started with trying to optimize using a technique called
    // SWAR, SIMD Within a Register. SWAR uses regular integer types like `u32`
    // and `u64` as SIMD registers to perform vertical operations in parallel,
    // usually involving bit-twiddling. This allowed each `BytesPerPixel` (bpp)
    // pixel to be decoded in parallel: 3bpp and 4bpp in a `u32`, 6bpp and 8bpp
    // in a `u64`. The `Sub` filter looked like the following code block, `Avg`
    // was similar but used a bitwise average method from [AVG]:
    // ```
    // // See "Unpartitioned Operations With Correction Code" from [SWAR]
    // fn swar_add_u32(x: u32, y: u32) -> u32 {
    //     // 7-bit addition so there's no carry over the most significant bit
    //     let n = (x & 0x7f7f7f7f) + (y & 0x7f7f7f7f); // 0x7F = 0b_0111_1111
    //     // 1-bit parity/XOR addition to fill in the missing MSB
    //     n ^ (x ^ y) & 0x80808080                     // 0x80 = 0b_1000_0000
    // }
    //
    // let mut prev =
    //     u32::from_ne_bytes([current[0], current[1], current[2], current[3]]);
    // for chunk in current[4..].chunks_exact_mut(4) {
    //     let cur = u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
    //     let new_chunk = swar_add_u32(cur, prev);
    //     chunk.copy_from_slice(&new_chunk.to_ne_bytes());
    //     prev = new_chunk;
    // }
    // ```
    // While this provided a measurable increase, @fintelia found that this idea
    // could be taken even further by unrolling the chunks component-wise and
    // avoiding unnecessary byte-shuffling by using byte arrays instead of
    // `u32::from|to_ne_bytes`. The bitwise operations were no longer necessary
    // so they were reverted to their obvious arithmetic equivalent. Lastly,
    // `TryInto` was used instead of `copy_from_slice`. The `Sub` code now
    // looked like this (with asserts to remove `0..bpp` bounds checks):
    // ```
    // assert!(len > 3);
    // let mut prev = [current[0], current[1], current[2], current[3]];
    // for chunk in current[4..].chunks_exact_mut(4) {
    //     let new_chunk = [
    //         chunk[0].wrapping_add(prev[0]),
    //         chunk[1].wrapping_add(prev[1]),
    //         chunk[2].wrapping_add(prev[2]),
    //         chunk[3].wrapping_add(prev[3]),
    //     ];
    //     *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
    //     prev = new_chunk;
    // }
    // ```
    // The compiler was able to optimize the code to be even faster and this
    // method even sped up Paeth filtering! Assertions were experimentally
    // added within loop bodies which produced better instructions but no
    // difference in speed. Finally, the code was refactored to remove manual
    // slicing and start the previous pixel chunks with arrays of `[0; N]`.
    // ```
    // let mut prev = [0; 4];
    // for chunk in current.chunks_exact_mut(4) {
    //     let new_chunk = [
    //         chunk[0].wrapping_add(prev[0]),
    //         chunk[1].wrapping_add(prev[1]),
    //         chunk[2].wrapping_add(prev[2]),
    //         chunk[3].wrapping_add(prev[3]),
    //     ];
    //     *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
    //     prev = new_chunk;
    // }
    // ```
    // While we're not manually bit-twiddling anymore, a possible takeaway from
    // this is to "think in SWAR" when dealing with small byte arrays. Unrolling
    // array operations and performing them component-wise may unlock previously
    // unavailable optimizations from the compiler, even when using the
    // `chunks_exact` methods for their potential auto-vectorization benefits.
    //
    // `std::simd` notes
    // =================
    //
    // In the past we have experimented with `std::simd` for unfiltering.  This
    // experiment was removed in https://github.com/image-rs/image-png/pull/585
    // because:
    //
    // * The crate's microbenchmarks showed that `std::simd` didn't have a
    //   significant advantage over auto-vectorization for most filters, except
    //   for Paeth unfiltering - see
    //   https://github.com/image-rs/image-png/pull/414#issuecomment-1736655668
    // * In the crate's microbenchmarks `std::simd` seemed to help with Paeth
    //   unfiltering only on x86/x64, with mixed results on ARM - see
    //   https://github.com/image-rs/image-png/pull/539#issuecomment-2512748043
    // * In Chromium end-to-end microbenchmarks `std::simd` either didn't help
    //   or resulted in a small regression (as measured on x64).  See
    //   https://crrev.com/c/6090592.
    // * Field trial data from some "real world" scenarios shows that
    //   performance can be quite good without relying on `std::simd` - see
    //   https://github.com/image-rs/image-png/discussions/562#discussioncomment-13303307
    match filter {
        NoFilter => {}
        Sub => match tbpp {
            BytesPerPixel::One => {
                current.iter_mut().reduce(|&mut prev, curr| {
                    *curr = curr.wrapping_add(prev);
                    curr
                });
            }
            BytesPerPixel::Two => {
                let mut prev = [0; 2];
                for chunk in current.chunks_exact_mut(2) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0]),
                        chunk[1].wrapping_add(prev[1]),
                    ];
                    *TryInto::<&mut [u8; 2]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Three => {
                let mut prev = [0; 3];
                for chunk in current.chunks_exact_mut(3) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0]),
                        chunk[1].wrapping_add(prev[1]),
                        chunk[2].wrapping_add(prev[2]),
                    ];
                    *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Four => {
                let mut prev = [0; 4];
                for chunk in current.chunks_exact_mut(4) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0]),
                        chunk[1].wrapping_add(prev[1]),
                        chunk[2].wrapping_add(prev[2]),
                        chunk[3].wrapping_add(prev[3]),
                    ];
                    *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Six => {
                let mut prev = [0; 6];
                for chunk in current.chunks_exact_mut(6) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0]),
                        chunk[1].wrapping_add(prev[1]),
                        chunk[2].wrapping_add(prev[2]),
                        chunk[3].wrapping_add(prev[3]),
                        chunk[4].wrapping_add(prev[4]),
                        chunk[5].wrapping_add(prev[5]),
                    ];
                    *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Eight => {
                let mut prev = [0; 8];
                for chunk in current.chunks_exact_mut(8) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0]),
                        chunk[1].wrapping_add(prev[1]),
                        chunk[2].wrapping_add(prev[2]),
                        chunk[3].wrapping_add(prev[3]),
                        chunk[4].wrapping_add(prev[4]),
                        chunk[5].wrapping_add(prev[5]),
                        chunk[6].wrapping_add(prev[6]),
                        chunk[7].wrapping_add(prev[7]),
                    ];
                    *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
        },
        Up => {
            for (curr, &above) in current.iter_mut().zip(previous) {
                *curr = curr.wrapping_add(above);
            }
        }
        Avg if previous.is_empty() => match tbpp {
            BytesPerPixel::One => {
                current.iter_mut().reduce(|&mut prev, curr| {
                    *curr = curr.wrapping_add(prev / 2);
                    curr
                });
            }
            BytesPerPixel::Two => {
                let mut prev = [0; 2];
                for chunk in current.chunks_exact_mut(2) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0] / 2),
                        chunk[1].wrapping_add(prev[1] / 2),
                    ];
                    *TryInto::<&mut [u8; 2]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Three => {
                let mut prev = [0; 3];
                for chunk in current.chunks_exact_mut(3) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0] / 2),
                        chunk[1].wrapping_add(prev[1] / 2),
                        chunk[2].wrapping_add(prev[2] / 2),
                    ];
                    *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Four => {
                let mut prev = [0; 4];
                for chunk in current.chunks_exact_mut(4) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0] / 2),
                        chunk[1].wrapping_add(prev[1] / 2),
                        chunk[2].wrapping_add(prev[2] / 2),
                        chunk[3].wrapping_add(prev[3] / 2),
                    ];
                    *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Six => {
                let mut prev = [0; 6];
                for chunk in current.chunks_exact_mut(6) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0] / 2),
                        chunk[1].wrapping_add(prev[1] / 2),
                        chunk[2].wrapping_add(prev[2] / 2),
                        chunk[3].wrapping_add(prev[3] / 2),
                        chunk[4].wrapping_add(prev[4] / 2),
                        chunk[5].wrapping_add(prev[5] / 2),
                    ];
                    *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
            BytesPerPixel::Eight => {
                let mut prev = [0; 8];
                for chunk in current.chunks_exact_mut(8) {
                    let new_chunk = [
                        chunk[0].wrapping_add(prev[0] / 2),
                        chunk[1].wrapping_add(prev[1] / 2),
                        chunk[2].wrapping_add(prev[2] / 2),
                        chunk[3].wrapping_add(prev[3] / 2),
                        chunk[4].wrapping_add(prev[4] / 2),
                        chunk[5].wrapping_add(prev[5] / 2),
                        chunk[6].wrapping_add(prev[6] / 2),
                        chunk[7].wrapping_add(prev[7] / 2),
                    ];
                    *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
                    prev = new_chunk;
                }
            }
        },
        Avg => match tbpp {
            BytesPerPixel::One => {
                let mut lprev = [0; 1];
                for (chunk, above) in current.chunks_exact_mut(1).zip(previous.chunks_exact(1)) {
                    let new_chunk =
                        [chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8)];
                    *TryInto::<&mut [u8; 1]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
            BytesPerPixel::Two => {
                let mut lprev = [0; 2];
                for (chunk, above) in current.chunks_exact_mut(2).zip(previous.chunks_exact(2)) {
                    let new_chunk = [
                        chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
                        chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
                    ];
                    *TryInto::<&mut [u8; 2]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
            BytesPerPixel::Three => {
                let mut lprev = [0; 3];
                for (chunk, above) in current.chunks_exact_mut(3).zip(previous.chunks_exact(3)) {
                    let new_chunk = [
                        chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
                        chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
                        chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
                    ];
                    *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
            BytesPerPixel::Four => {
                let mut lprev = [0; 4];
                for (chunk, above) in current.chunks_exact_mut(4).zip(previous.chunks_exact(4)) {
                    let new_chunk = [
                        chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
                        chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
                        chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
                        chunk[3].wrapping_add(((above[3] as u16 + lprev[3] as u16) / 2) as u8),
                    ];
                    *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
            BytesPerPixel::Six => {
                let mut lprev = [0; 6];
                for (chunk, above) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6)) {
                    let new_chunk = [
                        chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
                        chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
                        chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
                        chunk[3].wrapping_add(((above[3] as u16 + lprev[3] as u16) / 2) as u8),
                        chunk[4].wrapping_add(((above[4] as u16 + lprev[4] as u16) / 2) as u8),
                        chunk[5].wrapping_add(((above[5] as u16 + lprev[5] as u16) / 2) as u8),
                    ];
                    *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
            BytesPerPixel::Eight => {
                let mut lprev = [0; 8];
                for (chunk, above) in current.chunks_exact_mut(8).zip(previous.chunks_exact(8)) {
                    let new_chunk = [
                        chunk[0].wrapping_add(((above[0] as u16 + lprev[0] as u16) / 2) as u8),
                        chunk[1].wrapping_add(((above[1] as u16 + lprev[1] as u16) / 2) as u8),
                        chunk[2].wrapping_add(((above[2] as u16 + lprev[2] as u16) / 2) as u8),
                        chunk[3].wrapping_add(((above[3] as u16 + lprev[3] as u16) / 2) as u8),
                        chunk[4].wrapping_add(((above[4] as u16 + lprev[4] as u16) / 2) as u8),
                        chunk[5].wrapping_add(((above[5] as u16 + lprev[5] as u16) / 2) as u8),
                        chunk[6].wrapping_add(((above[6] as u16 + lprev[6] as u16) / 2) as u8),
                        chunk[7].wrapping_add(((above[7] as u16 + lprev[7] as u16) / 2) as u8),
                    ];
                    *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
                    lprev = new_chunk;
                }
            }
        },
        #[allow(unreachable_code)]
        Paeth => {
            // Select the fastest Paeth filter implementation based on the target architecture.
            let filter_paeth_decode = if cfg!(target_arch = "x86_64") {
                filter_paeth_stbi
            } else {
                filter_paeth
            };

            // Paeth filter pixels:
            // C B D
            // A X
            match tbpp {
                BytesPerPixel::One => {
                    let mut a_bpp = [0; 1];
                    let mut c_bpp = [0; 1];
                    for (chunk, b_bpp) in current.chunks_exact_mut(1).zip(previous.chunks_exact(1))
                    {
                        let new_chunk = [chunk[0]
                            .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0]))];
                        *TryInto::<&mut [u8; 1]>::try_into(chunk).unwrap() = new_chunk;
                        a_bpp = new_chunk;
                        c_bpp = b_bpp.try_into().unwrap();
                    }
                }
                BytesPerPixel::Two => {
                    let mut a_bpp = [0; 2];
                    let mut c_bpp = [0; 2];
                    for (chunk, b_bpp) in current.chunks_exact_mut(2).zip(previous.chunks_exact(2))
                    {
                        let new_chunk = [
                            chunk[0]
                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
                            chunk[1]
                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
                        ];
                        *TryInto::<&mut [u8; 2]>::try_into(chunk).unwrap() = new_chunk;
                        a_bpp = new_chunk;
                        c_bpp = b_bpp.try_into().unwrap();
                    }
                }
                BytesPerPixel::Three => {
                    let mut a_bpp = [0; 3];
                    let mut c_bpp = [0; 3];

                    let mut previous = &previous[..previous.len() / 3 * 3];
                    let current_len = current.len();
                    let mut current = &mut current[..current_len / 3 * 3];

                    while let ([c0, c1, c2, c_rest @ ..], [p0, p1, p2, p_rest @ ..]) =
                        (current, previous)
                    {
                        current = c_rest;
                        previous = p_rest;

                        *c0 = c0.wrapping_add(filter_paeth_decode(a_bpp[0], *p0, c_bpp[0]));
                        *c1 = c1.wrapping_add(filter_paeth_decode(a_bpp[1], *p1, c_bpp[1]));
                        *c2 = c2.wrapping_add(filter_paeth_decode(a_bpp[2], *p2, c_bpp[2]));

                        a_bpp = [*c0, *c1, *c2];
                        c_bpp = [*p0, *p1, *p2];
                    }
                }
                BytesPerPixel::Four => {
                    // Using the `simd` module here has no effect on Linux
                    // and appears to regress performance on Windows, so we don't use it here.
                    // See https://github.com/image-rs/image-png/issues/567

                    let mut a_bpp = [0; 4];
                    let mut c_bpp = [0; 4];

                    let mut previous = &previous[..previous.len() & !3];
                    let current_len = current.len();
                    let mut current = &mut current[..current_len & !3];

                    while let ([c0, c1, c2, c3, c_rest @ ..], [p0, p1, p2, p3, p_rest @ ..]) =
                        (current, previous)
                    {
                        current = c_rest;
                        previous = p_rest;

                        *c0 = c0.wrapping_add(filter_paeth_decode(a_bpp[0], *p0, c_bpp[0]));
                        *c1 = c1.wrapping_add(filter_paeth_decode(a_bpp[1], *p1, c_bpp[1]));
                        *c2 = c2.wrapping_add(filter_paeth_decode(a_bpp[2], *p2, c_bpp[2]));
                        *c3 = c3.wrapping_add(filter_paeth_decode(a_bpp[3], *p3, c_bpp[3]));

                        a_bpp = [*c0, *c1, *c2, *c3];
                        c_bpp = [*p0, *p1, *p2, *p3];
                    }
                }
                BytesPerPixel::Six => {
                    let mut a_bpp = [0; 6];
                    let mut c_bpp = [0; 6];
                    for (chunk, b_bpp) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
                    {
                        let new_chunk = [
                            chunk[0]
                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
                            chunk[1]
                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
                            chunk[2]
                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
                            chunk[3]
                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
                            chunk[4]
                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
                            chunk[5]
                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
                        ];
                        *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
                        a_bpp = new_chunk;
                        c_bpp = b_bpp.try_into().unwrap();
                    }
                }
                BytesPerPixel::Eight => {
                    let mut a_bpp = [0; 8];
                    let mut c_bpp = [0; 8];
                    for (chunk, b_bpp) in current.chunks_exact_mut(8).zip(previous.chunks_exact(8))
                    {
                        let new_chunk = [
                            chunk[0]
                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
                            chunk[1]
                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
                            chunk[2]
                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
                            chunk[3]
                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
                            chunk[4]
                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
                            chunk[5]
                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
                            chunk[6]
                                .wrapping_add(filter_paeth_decode(a_bpp[6], b_bpp[6], c_bpp[6])),
                            chunk[7]
                                .wrapping_add(filter_paeth_decode(a_bpp[7], b_bpp[7], c_bpp[7])),
                        ];
                        *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
                        a_bpp = new_chunk;
                        c_bpp = b_bpp.try_into().unwrap();
                    }
                }
            }
        }
    }
}

fn filter_internal(
    method: RowFilter,
    bpp: usize,
    len: usize,
    previous: &[u8],
    current: &[u8],
    output: &mut [u8],
) -> RowFilter {
    use self::RowFilter::*;

    // This value was chosen experimentally based on what achieved the best performance. The
    // Rust compiler does auto-vectorization, and 32-bytes per loop iteration seems to enable
    // the fastest code when doing so.
    const CHUNK_SIZE: usize = 32;

    match method {
        NoFilter => {
            output.copy_from_slice(current);
            NoFilter
        }
        Sub => {
            let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
            let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
            let mut prev_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);

            for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) {
                for i in 0..CHUNK_SIZE {
                    out[i] = cur[i].wrapping_sub(prev[i]);
                }
            }

            for ((out, cur), &prev) in out_chunks
                .into_remainder()
                .iter_mut()
                .zip(cur_chunks.remainder())
                .zip(prev_chunks.remainder())
            {
                *out = cur.wrapping_sub(prev);
            }

            output[..bpp].copy_from_slice(&current[..bpp]);
            Sub
        }
        Up => {
            let mut out_chunks = output.chunks_exact_mut(CHUNK_SIZE);
            let mut cur_chunks = current.chunks_exact(CHUNK_SIZE);
            let mut prev_chunks = previous.chunks_exact(CHUNK_SIZE);

            for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) {
                for i in 0..CHUNK_SIZE {
                    out[i] = cur[i].wrapping_sub(prev[i]);
                }
            }

            for ((out, cur), &prev) in out_chunks
                .into_remainder()
                .iter_mut()
                .zip(cur_chunks.remainder())
                .zip(prev_chunks.remainder())
            {
                *out = cur.wrapping_sub(prev);
            }
            Up
        }
        Avg => {
            let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
            let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
            let mut cur_minus_bpp_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);
            let mut prev_chunks = previous[bpp..].chunks_exact(CHUNK_SIZE);

            for (((out, cur), cur_minus_bpp), prev) in (&mut out_chunks)
                .zip(&mut cur_chunks)
                .zip(&mut cur_minus_bpp_chunks)
                .zip(&mut prev_chunks)
            {
                for i in 0..CHUNK_SIZE {
                    // Bitwise average of two integers without overflow and
                    // without converting to a wider bit-width. See:
                    // http://aggregate.org/MAGIC/#Average%20of%20Integers
                    // If this is unrolled by component, consider reverting to
                    // `((cur_minus_bpp[i] as u16 + prev[i] as u16) / 2) as u8`
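                    // Worked example (illustrative values): avg(201, 103)
                    //   = (201 & 103) + ((201 ^ 103) >> 1)
                    //   = 65 + (174 >> 1) = 65 + 87 = 152 = (201 + 103) / 2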
                    out[i] = cur[i].wrapping_sub(
                        (cur_minus_bpp[i] & prev[i]) + ((cur_minus_bpp[i] ^ prev[i]) >> 1),
                    );
                }
            }

            for (((out, cur), &cur_minus_bpp), &prev) in out_chunks
                .into_remainder()
                .iter_mut()
                .zip(cur_chunks.remainder())
                .zip(cur_minus_bpp_chunks.remainder())
                .zip(prev_chunks.remainder())
            {
                *out = cur.wrapping_sub((cur_minus_bpp & prev) + ((cur_minus_bpp ^ prev) >> 1));
            }

            for i in 0..bpp {
                output[i] = current[i].wrapping_sub(previous[i] / 2);
            }
            Avg
        }
        Paeth => {
            let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
            let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
            let mut a_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);
            let mut b_chunks = previous[bpp..].chunks_exact(CHUNK_SIZE);
            let mut c_chunks = previous[..len - bpp].chunks_exact(CHUNK_SIZE);

            for ((((out, cur), a), b), c) in (&mut out_chunks)
                .zip(&mut cur_chunks)
                .zip(&mut a_chunks)
                .zip(&mut b_chunks)
                .zip(&mut c_chunks)
            {
                for i in 0..CHUNK_SIZE {
                    out[i] = cur[i].wrapping_sub(filter_paeth_fpnge(a[i], b[i], c[i]));
                }
            }

            for ((((out, cur), &a), &b), &c) in out_chunks
                .into_remainder()
                .iter_mut()
                .zip(cur_chunks.remainder())
                .zip(a_chunks.remainder())
                .zip(b_chunks.remainder())
                .zip(c_chunks.remainder())
            {
                *out = cur.wrapping_sub(filter_paeth_fpnge(a, b, c));
            }

            for i in 0..bpp {
                output[i] = current[i].wrapping_sub(filter_paeth_fpnge(0, previous[i], 0));
            }
            Paeth
        }
    }
}

pub(crate) fn filter(
    method: Filter,
    bpp: BytesPerPixel,
    previous: &[u8],
    current: &[u8],
    output: &mut [u8],
) -> RowFilter {
    use RowFilter::*;
    let bpp = bpp.into_usize();
    let len = current.len();

    match method {
        Filter::Adaptive => {
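            // Adaptive heuristic: apply each candidate filter and keep the one whose
            // output has the smallest sum of absolute values (bytes reinterpreted as
            // signed), i.e. the "minimum sum of absolute differences" rule suggested
            // by the PNG specification for filter selection.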
            let mut min_sum: u64 = u64::MAX;
            let mut filter_choice = RowFilter::NoFilter;
            for &filter in [Sub, Up, Avg, Paeth].iter() {
                filter_internal(filter, bpp, len, previous, current, output);
                let sum = sum_buffer(output);
                if sum <= min_sum {
                    min_sum = sum;
                    filter_choice = filter;
                }
            }

            if filter_choice != Paeth {
                filter_internal(filter_choice, bpp, len, previous, current, output);
            }
            filter_choice
        }
        _ => {
            let filter = RowFilter::from_method(method).unwrap();
            filter_internal(filter, bpp, len, previous, current, output)
        }
    }
}

// Helper function for Adaptive filter buffer summation: sums the absolute
// values of the buffer's bytes reinterpreted as `i8`, so that small positive
// and small negative deltas both score as "close to zero".
fn sum_buffer(buf: &[u8]) -> u64 {
    const CHUNK_SIZE: usize = 32;

    let mut buf_chunks = buf.chunks_exact(CHUNK_SIZE);
    let mut sum = 0_u64;

    for chunk in &mut buf_chunks {
        // At most, `acc` can be `32 * (i8::MIN as u8) = 32 * 128 = 4096`.
        let mut acc = 0;
        for &b in chunk {
            acc += u64::from((b as i8).unsigned_abs());
        }
        sum = sum.saturating_add(acc);
    }

    let mut acc = 0;
    for &b in buf_chunks.remainder() {
        acc += u64::from((b as i8).unsigned_abs());
    }

    sum.saturating_add(acc)
}

#[cfg(test)]
mod test {
    use super::*;
    use core::iter;

    #[test]
    fn roundtrip() {
        // A multiple of 8, 6, 4, 3, 2, 1
        const LEN: u8 = 240;
        let previous: Vec<_> = iter::repeat(1).take(LEN.into()).collect();
        let current: Vec<_> = (0..LEN).collect();
        let expected = current.clone();

        let roundtrip = |kind: RowFilter, bpp: BytesPerPixel| {
            let mut output = vec![0; LEN.into()];
            filter(kind.into(), bpp, &previous, &current, &mut output);
            unfilter(kind, bpp, &previous, &mut output);
            assert_eq!(
                output, expected,
                "Filtering {:?} with {:?} does not roundtrip",
                bpp, kind
            );
        };

        let filters = [
            RowFilter::NoFilter,
            RowFilter::Sub,
            RowFilter::Up,
            RowFilter::Avg,
            RowFilter::Paeth,
        ];

        let bpps = [
            BytesPerPixel::One,
            BytesPerPixel::Two,
            BytesPerPixel::Three,
            BytesPerPixel::Four,
            BytesPerPixel::Six,
            BytesPerPixel::Eight,
        ];

        for &filter in filters.iter() {
            for &bpp in bpps.iter() {
                roundtrip(filter, bpp);
            }
        }
    }

    #[test]
    #[ignore] // takes ~20s without optimizations
    fn paeth_impls_are_equivalent() {
        for a in 0..=255 {
            for b in 0..=255 {
                for c in 0..=255 {
                    let baseline = filter_paeth(a, b, c);
                    let fpnge = filter_paeth_fpnge(a, b, c);
                    let stbi = filter_paeth_stbi(a, b, c);

                    assert_eq!(baseline, fpnge);
                    assert_eq!(baseline, stbi);
                }
            }
        }
    }
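
    // A direct transcription of the `PaethPredictor` pseudocode from the PNG
    // specification, used here only as a test oracle for the optimized
    // implementations above (a verification sketch, not used by the codec).
    fn filter_paeth_reference(a: u8, b: u8, c: u8) -> u8 {
        let p = i16::from(a) + i16::from(b) - i16::from(c);
        let pa = (p - i16::from(a)).abs();
        let pb = (p - i16::from(b)).abs();
        let pc = (p - i16::from(c)).abs();
        if pa <= pb && pa <= pc {
            a
        } else if pb <= pc {
            b
        } else {
            c
        }
    }

    #[test]
    #[ignore] // exhaustive over all inputs; slow without optimizations
    fn paeth_impls_match_png_spec_reference() {
        for a in 0..=255 {
            for b in 0..=255 {
                for c in 0..=255 {
                    assert_eq!(filter_paeth(a, b, c), filter_paeth_reference(a, b, c));
                }
            }
        }
    }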

    #[test]
    fn roundtrip_ascending_previous_line() {
        // A multiple of 8, 6, 4, 3, 2, 1
        const LEN: u8 = 240;
        let previous: Vec<_> = (0..LEN).collect();
        let current: Vec<_> = (0..LEN).collect();
        let expected = current.clone();

        let roundtrip = |kind: RowFilter, bpp: BytesPerPixel| {
            let mut output = vec![0; LEN.into()];
            filter(kind.into(), bpp, &previous, &current, &mut output);
            unfilter(kind, bpp, &previous, &mut output);
            assert_eq!(
                output, expected,
                "Filtering {:?} with {:?} does not roundtrip",
                bpp, kind
            );
        };

        let filters = [
            RowFilter::NoFilter,
            RowFilter::Sub,
            RowFilter::Up,
            RowFilter::Avg,
            RowFilter::Paeth,
        ];

        let bpps = [
            BytesPerPixel::One,
            BytesPerPixel::Two,
            BytesPerPixel::Three,
            BytesPerPixel::Four,
            BytesPerPixel::Six,
            BytesPerPixel::Eight,
        ];

        for &filter in filters.iter() {
            for &bpp in bpps.iter() {
                roundtrip(filter, bpp);
            }
        }
    }

    #[test]
    // This tests that converting u8 to i8 doesn't overflow when taking the
    // absolute value for adaptive filtering: `(-128_i8).abs()` panics in debug
    // mode and wraps back to -128 in release mode, which is why `unsigned_abs`
    // is used instead. The sum of 0..=255u8 should equal the sum of the
    // absolute values of -128_i8..=127, i.e. abs(-128..=0) + (1..=127).
    fn sum_buffer_test() {
        let sum = (0..=128).sum::<u64>() + (1..=127).sum::<u64>();
        let buf: Vec<u8> = (0_u8..=255).collect();

        assert_eq!(sum, crate::filter::sum_buffer(&buf));
    }
}