palette/
matrix.rs

1//! This module provides simple matrix operations on 3x3 matrices to aid in
2//! chromatic adaptation and conversion calculations.
3
4use core::marker::PhantomData;
5
6use crate::{
7    convert::IntoColorUnclamped,
8    encoding::Linear,
9    num::{Arithmetics, FromScalar, IsValidDivisor, Recip},
10    rgb::{Primaries, Rgb, RgbSpace},
11    white_point::{Any, WhitePoint},
12    Xyz, Yxy,
13};
14
/// A 9 element array representing a 3x3 matrix.
///
/// The functions in this module read the elements in row-major order:
/// indices `0..3` form the first row, `3..6` the second and `6..9` the third.
pub type Mat3<T> = [T; 9];
17
18/// Multiply the 3x3 matrix with an XYZ color.
19#[inline]
20pub fn multiply_xyz<T>(c: Mat3<T>, f: Xyz<Any, T>) -> Xyz<Any, T>
21where
22    T: Arithmetics,
23{
24    // Input Mat3 is destructured to avoid panic paths
25    let [c0, c1, c2, c3, c4, c5, c6, c7, c8] = c;
26
27    let x1 = c0 * &f.x;
28    let y1 = c3 * &f.x;
29    let z1 = c6 * f.x;
30    let x2 = c1 * &f.y;
31    let y2 = c4 * &f.y;
32    let z2 = c7 * f.y;
33    let x3 = c2 * &f.z;
34    let y3 = c5 * &f.z;
35    let z3 = c8 * f.z;
36
37    Xyz {
38        x: x1 + x2 + x3,
39        y: y1 + y2 + y3,
40        z: z1 + z2 + z3,
41        white_point: PhantomData,
42    }
43}
44/// Multiply the 3x3 matrix with an XYZ color to return an RGB color.
45#[inline]
46pub fn multiply_xyz_to_rgb<S, V, T>(c: Mat3<T>, f: Xyz<S::WhitePoint, V>) -> Rgb<Linear<S>, V>
47where
48    S: RgbSpace,
49    V: Arithmetics + FromScalar<Scalar = T>,
50{
51    // Input Mat3 is destructured to avoid panic paths. red, green, and blue
52    // can't be extracted like in `multiply_xyz` to get a performance increase
53    let [c0, c1, c2, c3, c4, c5, c6, c7, c8] = c;
54
55    Rgb {
56        red: (V::from_scalar(c0) * &f.x)
57            + (V::from_scalar(c1) * &f.y)
58            + (V::from_scalar(c2) * &f.z),
59        green: (V::from_scalar(c3) * &f.x)
60            + (V::from_scalar(c4) * &f.y)
61            + (V::from_scalar(c5) * &f.z),
62        blue: (V::from_scalar(c6) * f.x) + (V::from_scalar(c7) * f.y) + (V::from_scalar(c8) * f.z),
63        standard: PhantomData,
64    }
65}
66/// Multiply the 3x3 matrix with an RGB color to return an XYZ color.
67#[inline]
68pub fn multiply_rgb_to_xyz<S, V, T>(c: Mat3<T>, f: Rgb<Linear<S>, V>) -> Xyz<S::WhitePoint, V>
69where
70    S: RgbSpace,
71    V: Arithmetics + FromScalar<Scalar = T>,
72{
73    // Input Mat3 is destructured to avoid panic paths. Same problem as
74    // `multiply_xyz_to_rgb` for extracting x, y, z
75    let [c0, c1, c2, c3, c4, c5, c6, c7, c8] = c;
76
77    Xyz {
78        x: (V::from_scalar(c0) * &f.red)
79            + (V::from_scalar(c1) * &f.green)
80            + (V::from_scalar(c2) * &f.blue),
81        y: (V::from_scalar(c3) * &f.red)
82            + (V::from_scalar(c4) * &f.green)
83            + (V::from_scalar(c5) * &f.blue),
84        z: (V::from_scalar(c6) * f.red)
85            + (V::from_scalar(c7) * f.green)
86            + (V::from_scalar(c8) * f.blue),
87        white_point: PhantomData,
88    }
89}
90
91/// Multiply two 3x3 matrices.
92#[inline]
93pub fn multiply_3x3<T>(c: Mat3<T>, f: Mat3<T>) -> Mat3<T>
94where
95    T: Arithmetics + Clone,
96{
97    // Input Mat3 are destructured to avoid panic paths
98    let [c0, c1, c2, c3, c4, c5, c6, c7, c8] = c;
99    let [f0, f1, f2, f3, f4, f5, f6, f7, f8] = f;
100
101    let o0 = c0.clone() * &f0 + c1.clone() * &f3 + c2.clone() * &f6;
102    let o1 = c0.clone() * &f1 + c1.clone() * &f4 + c2.clone() * &f7;
103    let o2 = c0 * &f2 + c1 * &f5 + c2 * &f8;
104
105    let o3 = c3.clone() * &f0 + c4.clone() * &f3 + c5.clone() * &f6;
106    let o4 = c3.clone() * &f1 + c4.clone() * &f4 + c5.clone() * &f7;
107    let o5 = c3 * &f2 + c4 * &f5 + c5 * &f8;
108
109    let o6 = c6.clone() * f0 + c7.clone() * f3 + c8.clone() * f6;
110    let o7 = c6.clone() * f1 + c7.clone() * f4 + c8.clone() * f7;
111    let o8 = c6 * f2 + c7 * f5 + c8 * f8;
112
113    [o0, o1, o2, o3, o4, o5, o6, o7, o8]
114}
115
116/// Invert a 3x3 matrix and panic if matrix is not invertible.
117#[inline]
118pub fn matrix_inverse<T>(a: Mat3<T>) -> Mat3<T>
119where
120    T: Recip + IsValidDivisor<Mask = bool> + Arithmetics + Clone,
121{
122    // This function runs fastest with assert and no destructuring. The `det`'s
123    // location should not be changed until benched that it's faster elsewhere
124    assert!(a.len() > 8);
125
126    let d0 = a[4].clone() * &a[8] - a[5].clone() * &a[7];
127    let d1 = a[3].clone() * &a[8] - a[5].clone() * &a[6];
128    let d2 = a[3].clone() * &a[7] - a[4].clone() * &a[6];
129    let mut det = a[0].clone() * &d0 - a[1].clone() * &d1 + a[2].clone() * &d2;
130    let d3 = a[1].clone() * &a[8] - a[2].clone() * &a[7];
131    let d4 = a[0].clone() * &a[8] - a[2].clone() * &a[6];
132    let d5 = a[0].clone() * &a[7] - a[1].clone() * &a[6];
133    let d6 = a[1].clone() * &a[5] - a[2].clone() * &a[4];
134    let d7 = a[0].clone() * &a[5] - a[2].clone() * &a[3];
135    let d8 = a[0].clone() * &a[4] - a[1].clone() * &a[3];
136
137    if !det.is_valid_divisor() {
138        panic!("The given matrix is not invertible")
139    }
140    det = det.recip();
141
142    [
143        d0 * &det,
144        -d3 * &det,
145        d6 * &det,
146        -d1 * &det,
147        d4 * &det,
148        -d7 * &det,
149        d2 * &det,
150        -d5 * &det,
151        d8 * det,
152    ]
153}
154
155/// Maps a matrix from one item type to another.
156///
157/// This turned out to be easier for the compiler to optimize than `matrix.map(f)`.
158#[inline(always)]
159pub fn matrix_map<T, U>(matrix: Mat3<T>, mut f: impl FnMut(T) -> U) -> Mat3<U> {
160    let [m1, m2, m3, m4, m5, m6, m7, m8, m9] = matrix;
161    [
162        f(m1),
163        f(m2),
164        f(m3),
165        f(m4),
166        f(m5),
167        f(m6),
168        f(m7),
169        f(m8),
170        f(m9),
171    ]
172}
173
174/// Generates the Srgb to Xyz transformation matrix for a given white point.
175#[inline]
176pub fn rgb_to_xyz_matrix<S, T>() -> Mat3<T>
177where
178    S: RgbSpace,
179    S::Primaries: Primaries<T>,
180    S::WhitePoint: WhitePoint<T>,
181    T: Recip + IsValidDivisor<Mask = bool> + Arithmetics + Clone + FromScalar<Scalar = T>,
182    Yxy<Any, T>: IntoColorUnclamped<Xyz<Any, T>>,
183{
184    let r = S::Primaries::red().into_color_unclamped();
185    let g = S::Primaries::green().into_color_unclamped();
186    let b = S::Primaries::blue().into_color_unclamped();
187
188    let matrix = mat3_from_primaries(r, g, b);
189
190    let s_matrix: Rgb<Linear<S>, T> = multiply_xyz_to_rgb(
191        matrix_inverse(matrix.clone()),
192        S::WhitePoint::get_xyz().with_white_point(),
193    );
194
195    // Destructuring has some performance benefits, don't change unless measured
196    let [t0, t1, t2, t3, t4, t5, t6, t7, t8] = matrix;
197
198    [
199        t0 * &s_matrix.red,
200        t1 * &s_matrix.green,
201        t2 * &s_matrix.blue,
202        t3 * &s_matrix.red,
203        t4 * &s_matrix.green,
204        t5 * &s_matrix.blue,
205        t6 * s_matrix.red,
206        t7 * s_matrix.green,
207        t8 * s_matrix.blue,
208    ]
209}
210
211#[rustfmt::skip]
212#[inline]
213fn mat3_from_primaries<T>(r: Xyz<Any, T>, g: Xyz<Any, T>, b: Xyz<Any, T>) -> Mat3<T> {
214    [
215        r.x, g.x, b.x,
216        r.y, g.y, b.y,
217        r.z, g.z, b.z,
218    ]
219}
220
#[cfg(feature = "approx")]
#[cfg(test)]
mod test {
    use super::{matrix_inverse, multiply_3x3, multiply_xyz, rgb_to_xyz_matrix};
    use crate::chromatic_adaptation::AdaptInto;
    use crate::encoding::{Linear, Srgb};
    use crate::rgb::Rgb;
    use crate::white_point::D50;
    use crate::Xyz;

    /// Asserts that two matrices are element-wise relatively equal, using the
    /// default tolerances of `assert_relative_eq!`.
    fn assert_mat3_relative_eq(expected: &[f64; 9], computed: &[f64; 9]) {
        for (want, got) in expected.iter().zip(computed.iter()) {
            assert_relative_eq!(want, got);
        }
    }

    #[test]
    fn matrix_multiply_3x3() {
        let lhs = [1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 2.0, 1.0, 3.0];
        let rhs = [4.0, 5.0, 6.0, 6.0, 5.0, 4.0, 4.0, 6.0, 5.0];
        let expected = [28.0, 33.0, 29.0, 28.0, 31.0, 31.0, 26.0, 33.0, 31.0];

        assert_mat3_relative_eq(&expected, &multiply_3x3(lhs, rhs));
    }

    #[test]
    fn matrix_multiply_xyz() {
        let matrix = [0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.2, 0.1, 0.3];
        let color = Xyz::new(0.4, 0.6, 0.8);

        let expected = Xyz::new(0.4, 0.32, 0.38);

        let computed = multiply_xyz(matrix, color);
        assert_relative_eq!(expected, computed)
    }

    #[test]
    fn matrix_inverse_check_1() {
        let matrix: [f64; 9] = [3.0, 0.0, 2.0, 2.0, 0.0, -2.0, 0.0, 1.0, 1.0];
        let expected: [f64; 9] = [0.2, 0.2, 0.0, -0.2, 0.3, 1.0, 0.2, -0.3, 0.0];

        assert_mat3_relative_eq(&expected, &matrix_inverse(matrix));
    }

    #[test]
    fn matrix_inverse_check_2() {
        let matrix: [f64; 9] = [1.0, 0.0, 1.0, 0.0, 2.0, 1.0, 1.0, 1.0, 1.0];
        let expected: [f64; 9] = [-1.0, -1.0, 2.0, -1.0, 0.0, 1.0, 2.0, 1.0, -2.0];

        assert_mat3_relative_eq(&expected, &matrix_inverse(matrix));
    }

    #[test]
    #[should_panic]
    fn matrix_inverse_panic() {
        // The second row is twice the first, so the determinant is zero.
        let singular: [f64; 9] = [1.0, 0.0, 0.0, 2.0, 0.0, 0.0, -4.0, 6.0, 1.0];
        matrix_inverse(singular);
    }

    #[rustfmt::skip]
    #[test]
    fn d65_rgb_conversion_matrix() {
        let expected = [
            0.4124564, 0.3575761, 0.1804375,
            0.2126729, 0.7151522, 0.0721750,
            0.0193339, 0.1191920, 0.9503041
        ];
        let computed = rgb_to_xyz_matrix::<Srgb, f64>();
        // The reference values are only given to seven decimals, hence the explicit epsilon.
        for (want, got) in expected.iter().zip(computed.iter()) {
            assert_relative_eq!(want, got, epsilon = 0.000001)
        }
    }

    #[test]
    fn d65_to_d50() {
        // White stays white when adapting to another white point.
        let d65_white: Rgb<Linear<Srgb>> = Rgb::new(1.0, 1.0, 1.0);
        let expected: Rgb<Linear<(Srgb, D50)>> = Rgb::new(1.0, 1.0, 1.0);

        let computed: Rgb<Linear<(Srgb, D50)>> = d65_white.adapt_into();
        assert_relative_eq!(expected, computed, epsilon = 0.000001);
    }
}